// TensorRT-LLMs/cpp/tests/unit_tests/kernels/mlaChunkedPrefillTest.cu
#include <gtest/gtest.h>

#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/quantization.h"
#include "tensorrt_llm/kernels/decodingCommon.h"
#include "tensorrt_llm/kernels/kvCacheUtils.h"
#include "tensorrt_llm/kernels/mlaChunkedPrefill.cuh"
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/cudaStream.h"

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <limits>
#include <memory>
#include <random>
#include <string>
#include <utility>
#include <vector>
// #define TRTLLM_MLA_CHUNKED_PREFILL_TEST_DBG
namespace
{
// kv_output {total_tokens, h=1, lora_size}
// k_pe_output {total_tokens, h=1, rope_size}
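// CPU reference for invokeMLALoadChunkedKV: for each sequence, reads chunked_len tokens starting at
// chunked_ld_global_offset[b] from the paged compressed KV cache, splits every head_size
// (= lora_size + rope_size) vector into the compressed-KV part and the RoPE (k_pe) part, and
// dequantizes with kv_scale_quant_orig when TCache is __nv_fp8_e4m3.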
template <typename T, typename TCache>
void loadChunkedKVKernelRef(T* kv_output, T* k_pe_output, tensorrt_llm::kernels::KVBlockArray const& kv_cache,
int num_contexts, int64_t const* cu_ctx_chunked_len, int64_t const* chunked_ld_global_offset, int const lora_size,
int const rope_size, float const* kv_scale_quant_orig_ptr)
{
int const head_size = lora_size + rope_size;
float const kv_scale_quant_orig = kv_scale_quant_orig_ptr ? kv_scale_quant_orig_ptr[0] : 1.0f;
for (int b = 0; b < num_contexts; b++)
{
int const chunked_len = cu_ctx_chunked_len[b + 1] - cu_ctx_chunked_len[b];
for (int s = 0; s < chunked_len; s++)
{
int const local_token_idx = chunked_ld_global_offset[b] + s;
int const st_token_offset = (cu_ctx_chunked_len[b] + s);
auto const* kv_src = reinterpret_cast<TCache const*>(kv_cache.getKBlockPtr(b, local_token_idx));
for (int d = 0; d < head_size; d++)
{
auto kv_block_idx = kv_cache.getKVLocalIdx(local_token_idx, 0, head_size, d);
auto src_data = kv_src[kv_block_idx];
T data;
if constexpr (std::is_same_v<TCache, __nv_fp8_e4m3>)
{
data = T(float(src_data) * kv_scale_quant_orig);
}
else
{
data = src_data;
}
if (d < lora_size)
{
kv_output[st_token_offset * lora_size + d] = data;
}
else
{
k_pe_output[st_token_offset * rope_size + (d - lora_size)] = data;
}
}
}
}
}
// Q {total_q, H, D}
// KV {total_kv, 2, H, D}
// softmax_sum {total_q, H, 2} // {max/sum}
// output {total_q, H, D}
// total_q <= total_kv
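// Naive CPU attention reference over packed Q {total_q, H, D} and interleaved KV {total_kv, 2, H, D}.
// Optionally applies a lower-right causal mask and, when return_softmax is set, records the per-row
// softmax max and sum so partial results can later be merged across chunks.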
template <typename T>
void selfAttentionRef(T* output, T* const Q, T* const KV, int batch_size, int num_heads, int64_t* const cu_seq_q_len,
int64_t* const cu_seq_kv_len, int head_size, bool return_softmax, float* softmax_sum, bool causal_mask)
{
for (int b = 0; b < batch_size; b++)
{
int curr_q_len = cu_seq_q_len[b + 1] - cu_seq_q_len[b];
int curr_kv_len = cu_seq_kv_len[b + 1] - cu_seq_kv_len[b];
int global_q_offset = cu_seq_q_len[b] * num_heads * head_size;
int global_kv_offset = cu_seq_kv_len[b] * 2 * num_heads * head_size;
int global_softmax_offset = cu_seq_q_len[b] * num_heads * 2;
if (curr_q_len == 0 || curr_kv_len == 0)
{
continue; // skip empty sequences
}
std::vector<float> P(curr_q_len * curr_kv_len);
for (int h = 0; h < num_heads; h++)
{
// BMM1
std::fill(P.begin(), P.end(), std::numeric_limits<float>::lowest());
T* const q_ptr = Q + global_q_offset + h * head_size;
T* const k_ptr = KV + global_kv_offset + h * head_size;
T* const v_ptr = k_ptr + num_heads * head_size;
T* output_ptr = output + global_q_offset + h * head_size;
for (int s_q = 0; s_q < curr_q_len; s_q++)
{
float softmax_max = std::numeric_limits<float>::lowest();
for (int s_kv = 0; s_kv < curr_kv_len; s_kv++)
{
// lower right mask
if (causal_mask && s_kv > curr_kv_len - curr_q_len + s_q)
{
break;
}
P[s_q * curr_kv_len + s_kv] = 0;
for (int d = 0; d < head_size; d++)
{
P[s_q * curr_kv_len + s_kv] += static_cast<float>(
q_ptr[s_q * num_heads * head_size + d] * k_ptr[s_kv * 2 * num_heads * head_size + d]);
}
if (softmax_max < P[s_q * curr_kv_len + s_kv])
{
softmax_max = P[s_q * curr_kv_len + s_kv];
}
}
for (int s_kv = 0; s_kv < curr_kv_len; s_kv++)
{
// lower right mask
if (causal_mask && s_kv > curr_kv_len - curr_q_len + s_q)
{
break;
}
P[s_q * curr_kv_len + s_kv] -= softmax_max;
}
if (return_softmax)
{
softmax_sum[global_softmax_offset + s_q * num_heads * 2 + h * 2] = softmax_max;
}
}
// softmax
for (int s_q = 0; s_q < curr_q_len; s_q++)
{
float sum = 0;
for (int s_kv = 0; s_kv < curr_kv_len; s_kv++)
{
P[s_q * curr_kv_len + s_kv] = std::exp(P[s_q * curr_kv_len + s_kv]);
sum += P[s_q * curr_kv_len + s_kv];
}
for (int s_kv = 0; s_kv < curr_kv_len; s_kv++)
{
P[s_q * curr_kv_len + s_kv] /= sum;
}
if (return_softmax)
{
softmax_sum[global_softmax_offset + s_q * num_heads * 2 + h * 2 + 1] = sum;
}
}
// BMM2
for (int s_q = 0; s_q < curr_q_len; s_q++)
{
for (int d = 0; d < head_size; d++)
{
output_ptr[s_q * num_heads * head_size + d] = 0;
for (int s_kv = 0; s_kv < curr_kv_len; s_kv++)
{
output_ptr[s_q * num_heads * head_size + d] += static_cast<T>(P[s_q * curr_kv_len + s_kv]
* static_cast<float>(v_ptr[s_kv * 2 * num_heads * head_size + d]));
}
}
}
}
}
}
// chunked_KV {total_chunk_token, 2, H, D}
// KV {total_kv_token, 2, H, D}
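// Gathers the current chunk of every sequence from the full KV tensor into a packed chunked buffer:
// the source slice starts chunked_ld_global_offset[b] tokens into sequence b and spans
// cu_chunked_seq_len[b + 1] - cu_chunked_seq_len[b] tokens.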
template <typename T>
void copyRelatedChunkedKV(T* chunked_kv, T* const kv, int64_t const* chunked_ld_global_offset, int batch_size,
int num_heads, int64_t* const cu_kv_seq_len, int64_t* const cu_chunked_seq_len, int head_size)
{
for (int b = 0; b < batch_size; b++)
{
int src_global_offset = (cu_kv_seq_len[b] + chunked_ld_global_offset[b]) * 2 * num_heads * head_size;
int dst_global_offset = cu_chunked_seq_len[b] * 2 * num_heads * head_size;
int copy_length = cu_chunked_seq_len[b + 1] - cu_chunked_seq_len[b];
if (copy_length <= 0)
{
continue; // skip empty sequences
}
std::memcpy(chunked_kv + dst_global_offset, kv + src_global_offset,
copy_length * 2 * num_heads * head_size * sizeof(T));
}
}
// chunked_KV {total_chunk_token, 2, H, D}
// KV {total_kv_token, 2, H, D}
// It will copy the last chunk of KV cache to chunked_KV cache and calculate the cu_chunked_seq_len
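// merge_op[b] is set to 2 (plain copy) when the last chunk covers the whole sequence and to
// 1 (merge) otherwise.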
template <typename T>
void copyFinalChunkedKV(T* chunked_kv, T* const kv, int chunk_size, int batch_size, int num_heads,
int64_t* const cu_kv_seq_len, int64_t* cu_chunked_seq_len, int head_size, int64_t* merge_op)
{
cu_chunked_seq_len[0] = 0;
for (int b = 0; b < batch_size; b++)
{
int curr_kv_len = cu_kv_seq_len[b + 1] - cu_kv_seq_len[b];
int last_chunk_size = curr_kv_len % chunk_size;
if (last_chunk_size == 0)
{
last_chunk_size = chunk_size; // the sequence length is a multiple of chunk_size: the last chunk is a full chunk
}
if (last_chunk_size == curr_kv_len)
{
merge_op[b] = 2; // no need to merge, just copy
}
else
{
merge_op[b] = 1;
}
cu_chunked_seq_len[b + 1] = cu_chunked_seq_len[b] + last_chunk_size;
int global_token_offset = cu_kv_seq_len[b] + curr_kv_len - last_chunk_size;
int copy_length = last_chunk_size;
if (copy_length <= 0)
{
printf("copy_length is zero for batch %d, skipping...\n", b);
continue; // skip empty sequences
}
int src_global_offset = global_token_offset * 2 * num_heads * head_size;
int dst_global_offset = cu_chunked_seq_len[b] * 2 * num_heads * head_size;
std::memcpy(chunked_kv + dst_global_offset, kv + src_global_offset,
copy_length * 2 * num_heads * head_size * sizeof(T));
}
}
template <typename WeightType>
float getTolerance(float scale = 1.f)
{
float tol = 0.0;
if constexpr (std::is_same_v<WeightType, uint8_t>)
{
tol = 0.1;
}
else if constexpr (std::is_same_v<WeightType, float>)
{
tol = 0.001;
}
else if constexpr (std::is_same_v<WeightType, half>)
{
tol = 0.005;
}
else if constexpr (std::is_same_v<WeightType, __nv_bfloat16>)
{
tol = 0.05;
}
// Scale the tolerance with the magnitude of the compared value, never dropping below the base tolerance.
return std::max(tol, std::abs(scale) * tol);
}
} // namespace
template <typename Typepair>
class MlaChunkedPrefillTest : public ::testing::Test
{
protected:
using DataType = typename Typepair::first_type;
using TCache = typename Typepair::second_type;
static_assert(std::is_same_v<DataType, TCache> || std::is_same_v<TCache, __nv_fp8_e4m3>,
"TCache must be either the same type as DataType or __nv_fp8_e4m3");
std::shared_ptr<tensorrt_llm::runtime::CudaStream> mStream;
tensorrt_llm::runtime::BufferManager::ITensorPtr h_kv_cache_tensor{nullptr}, h_kv_cache_tensor_ref{nullptr},
d_kv_cache_tensor{nullptr}, h_compressed_kv_cache_tensor{nullptr}, d_compressed_kv_cache_tensor{nullptr},
h_compressed_offset_tensor{nullptr}, d_compressed_offset_tensor{nullptr}, h_cu_kv_seq_lens{nullptr},
d_cu_kv_seq_lens{nullptr}, h_cu_chunk_lens{nullptr}, d_cu_chunk_lens{nullptr},
h_chunked_ld_global_offset{nullptr}, d_chunked_ld_global_offset{nullptr}, h_cu_q_seq_lens{nullptr},
d_cu_q_seq_lens{nullptr},
// for kernel 1
h_compressed_kv_output{nullptr}, d_compressed_kv_output{nullptr}, h_k_pe_output{nullptr},
d_k_pe_output{nullptr}, h_compressed_kv_output_ref{nullptr}, h_k_pe_output_ref{nullptr},
h_kv_scale_quant_orig{nullptr}, d_kv_scale_quant_orig{nullptr},
// for merge attn {kv_full_tensor = kv + k_pe}
m_h_q_tensor{nullptr}, m_h_kv_full_tensor{nullptr}, m_h_chunked_kv_tensor{nullptr}, m_h_output_tensor{nullptr},
m_h_softmax_sum_tensor{nullptr}, m_h_softmax_sum_accum_tensor{nullptr}, m_h_output_tensor_ref{nullptr},
m_h_output_tensor_accum{nullptr}, m_d_q_tensor{nullptr}, m_d_kv_full_tensor{nullptr},
m_d_chunked_kv_tensor{nullptr}, m_d_output_tensor{nullptr}, m_d_softmax_sum_tensor{nullptr},
m_d_softmax_sum_accum_tensor{nullptr}, m_d_output_tensor_accum{nullptr}, m_h_merge_op{nullptr},
m_d_merge_op{nullptr};
int mBatchSize{};
int mMaxSeqLen{};
int mMaxQSeqLen{};
int mTotalQLen{};
int mTotalKVLen{};
int mChunkSize{};
int mNumHeads{};
int mLoraSize{};
int mRopeSize{};
int mNopeSize{};
int mMaxGenLength{};
// int mHeadSize{};
int mTokensPerBlock{};
int mMaxBlockPerSeq{};
bool mIsCausalMask{};
// for chunked main loop
std::vector<int> max_chunk_len_per_loop;
std::mt19937 gen;
void SetUp() override
{
if (shouldSkip())
{
GTEST_SKIP() << "Skipping mla chunked prefill test";
}
mStream = std::make_shared<tensorrt_llm::runtime::CudaStream>();
gen.seed(42U);
}
static bool shouldSkip()
{
return false;
}
void setDefaultParams()
{
mBatchSize = 16;
// mMaxSeqLen = 128;
mChunkSize = 16;
mNumHeads = 16;
mLoraSize = 512;
mRopeSize = 64;
mNopeSize = 128;
mIsCausalMask = false;
mMaxGenLength = 128;
mTokensPerBlock = 16;
assert(this->mChunkSize % this->mTokensPerBlock == 0);
}
void memsetZeroHost(tensorrt_llm::runtime::BufferManager::ITensorPtr& tensor)
{
void* ptr = tensor->data();
std::memset(ptr, 0, tensor->getSizeInBytes());
}
template <typename T>
void showHostTensor(tensorrt_llm::runtime::BufferManager::ITensorPtr& tensor)
{
auto* const ptr = reinterpret_cast<T*>(tensor->data());
for (size_t i = 0; i < tensor->getSize(); i++)
{
std::cout << static_cast<float>(ptr[i]) << " ";
}
std::cout << std::endl;
}
template <typename T>
void showHostTensor(tensorrt_llm::runtime::BufferManager::ITensorPtr& tensor, std::string const& tensor_name)
{
std::cout << "Tensor: " << tensor_name << ": \n";
showHostTensor<T>(tensor);
}
int generateRandomSizeSmallerThan(int a)
{
if (a <= 0)
{
return 0;
}
std::uniform_int_distribution<> distrib(0, a - 1);
// Generate and return the random number
return int{distrib(gen)};
}
float generateRandomFloat(float min, float max)
{
std::uniform_real_distribution<float> dist(min, max);
return dist(gen);
}
template <typename T>
void generateRandomData(T* data, int size)
{
for (int i = 0; i < size; i++)
{
data[i] = static_cast<T>(generateRandomFloat(-1.0f, 1.0f));
}
}
template <typename T>
void fillKVOffsetData(T* arr, size_t size, bool use_both_kv = true, int max_block_per_seq = 0)
{
if (use_both_kv)
{
for (int i = 0; i < size; i++)
{
arr[i] = static_cast<T>(i);
}
}
else
{
int temp_idx = 0;
for (int i = 0; i < size; i++)
{
bool is_v = (((i / max_block_per_seq) % 2) == 1);
if (is_v)
{
arr[i] = static_cast<T>(0);
}
else
{
arr[i] = static_cast<T>(temp_idx);
temp_idx++;
}
}
}
}
template <typename T>
void fillArrayDataWithMod(T* arr, size_t size)
{
for (int i = 0; i < size; i++)
{
arr[i] = static_cast<T>(i % 448);
}
}
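// Allocates and initializes all host/device buffers: random per-sequence q/kv lengths, the paged
// compressed KV cache (plus the FP8 dequant scale when TCache is __nv_fp8_e4m3), the outputs of the
// chunked-load kernel, and the Q/KV/output/softmax buffers used by the merge-attention path. Also
// derives the chunked-prefill metadata via prepareChunkedPrefillMetaData().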
bool allocateBuffers()
{
using tensorrt_llm::runtime::BufferManager;
using tensorrt_llm::runtime::CudaStream;
using tensorrt_llm::runtime::ITensor;
using tensorrt_llm::runtime::bufferCast;
auto dtype = nvinfer1::DataType::kHALF;
if constexpr (std::is_same_v<DataType, float>)
{
dtype = nvinfer1::DataType::kFLOAT;
}
else if constexpr (std::is_same_v<DataType, half>)
{
dtype = nvinfer1::DataType::kHALF;
}
else if constexpr (std::is_same_v<DataType, __nv_bfloat16>)
{
dtype = nvinfer1::DataType::kBF16;
}
else
{
return false;
}
auto cacheType = dtype;
if constexpr (std::is_same_v<TCache, __nv_fp8_e4m3>)
{
cacheType = nvinfer1::DataType::kFP8;
this->h_kv_scale_quant_orig
= tensorrt_llm::runtime::BufferManager::pinned(ITensor::makeShape({1}), nvinfer1::DataType::kFLOAT);
this->d_kv_scale_quant_orig
= tensorrt_llm::runtime::BufferManager::gpuSync(ITensor::makeShape({1}), nvinfer1::DataType::kFLOAT);
auto* kv_scale_quant_orig_ptr = bufferCast<float>(*(this->h_kv_scale_quant_orig));
float kv_scale_orig_quant = 2.0F;
kv_scale_quant_orig_ptr[0] = 1.0 / kv_scale_orig_quant;
cudaMemcpy(this->d_kv_scale_quant_orig->data(), this->h_kv_scale_quant_orig->data(),
this->h_kv_scale_quant_orig->getSizeInBytes(), cudaMemcpyHostToDevice);
}
// cu lens
this->h_cu_kv_seq_lens = tensorrt_llm::runtime::BufferManager::pinned(
ITensor::makeShape({this->mBatchSize + 1}), nvinfer1::DataType::kINT64);
this->h_cu_q_seq_lens = tensorrt_llm::runtime::BufferManager::pinned(
ITensor::makeShape({this->mBatchSize + 1}), nvinfer1::DataType::kINT64);
this->d_cu_kv_seq_lens = tensorrt_llm::runtime::BufferManager::gpuSync(
this->h_cu_kv_seq_lens->getShape(), nvinfer1::DataType::kINT64);
this->d_cu_q_seq_lens = tensorrt_llm::runtime::BufferManager::gpuSync(
this->h_cu_q_seq_lens->getShape(), nvinfer1::DataType::kINT64);
{
this->mMaxSeqLen = 0;
this->mMaxQSeqLen = 0;
this->mTotalQLen = 0;
this->mTotalKVLen = 0;
// we only initialize cu_seq_lens
auto* cu_kv_seq_lens_ptr = bufferCast<int64_t>(*(this->h_cu_kv_seq_lens));
auto* cu_q_seq_lens_ptr = bufferCast<int64_t>(*(this->h_cu_q_seq_lens));
cu_kv_seq_lens_ptr[0] = 0;
cu_q_seq_lens_ptr[0] = 0;
for (int i = 0; i < this->mBatchSize; i++)
{
int temp_seq_len = this->generateRandomSizeSmallerThan(this->mMaxGenLength);
if (temp_seq_len == 0)
{
temp_seq_len = 1; // ensure at least one token
}
this->mMaxSeqLen = std::max(this->mMaxSeqLen, temp_seq_len);
cu_kv_seq_lens_ptr[i + 1] = cu_kv_seq_lens_ptr[i] + temp_seq_len;
auto temp_q_seq_len = temp_seq_len % this->mChunkSize;
if (temp_q_seq_len == 0)
{
temp_q_seq_len = this->mChunkSize; // seq len is a multiple of the chunk size: the uncached (query) part is one full chunk
}
cu_q_seq_lens_ptr[i + 1] = cu_q_seq_lens_ptr[i] + temp_q_seq_len;
this->mMaxQSeqLen = std::max(this->mMaxQSeqLen, temp_q_seq_len);
this->mTotalQLen += temp_q_seq_len;
this->mTotalKVLen += temp_seq_len;
}
cudaMemcpy(this->d_cu_kv_seq_lens->data(), this->h_cu_kv_seq_lens->data(),
this->h_cu_kv_seq_lens->getSizeInBytes(), cudaMemcpyHostToDevice);
cudaMemcpy(this->d_cu_q_seq_lens->data(), this->h_cu_q_seq_lens->data(),
this->h_cu_q_seq_lens->getSizeInBytes(), cudaMemcpyHostToDevice);
#ifdef TRTLLM_MLA_CHUNKED_PREFILL_TEST_DBG
this->showHostTensor<int64_t>(this->h_cu_q_seq_lens, "cu_q_seq_lens");
this->showHostTensor<int64_t>(this->h_cu_kv_seq_lens, "cu_kv_seq_lens");
#endif
}
int const total_chunk_size = this->mChunkSize * this->mBatchSize;
int const total_cached_kv_len = this->mTotalKVLen - this->mTotalQLen;
int const chunked_loop_num = (total_cached_kv_len + total_chunk_size - 1) / total_chunk_size;
this->h_cu_chunk_lens = tensorrt_llm::runtime::BufferManager::pinned(
ITensor::makeShape({chunked_loop_num + 1, this->mBatchSize + 1}), nvinfer1::DataType::kINT64);
this->h_chunked_ld_global_offset = tensorrt_llm::runtime::BufferManager::pinned(
ITensor::makeShape({chunked_loop_num + 1, this->mBatchSize}), nvinfer1::DataType::kINT64);
this->memsetZeroHost(this->h_chunked_ld_global_offset);
this->d_cu_chunk_lens = tensorrt_llm::runtime::BufferManager::gpuSync(
this->h_cu_chunk_lens->getShape(), nvinfer1::DataType::kINT64);
this->d_chunked_ld_global_offset = tensorrt_llm::runtime::BufferManager::gpuSync(
this->h_chunked_ld_global_offset->getShape(), nvinfer1::DataType::kINT64);
// kv cache
this->mMaxBlockPerSeq = (this->mMaxSeqLen + this->mTokensPerBlock - 1) / this->mTokensPerBlock;
int maxChunkBlockPerSeq = (this->mChunkSize + this->mTokensPerBlock - 1) / this->mTokensPerBlock;
this->h_kv_cache_tensor = tensorrt_llm::runtime::BufferManager::pinned(
ITensor::makeShape({this->mBatchSize, 2, maxChunkBlockPerSeq, this->mNumHeads, this->mTokensPerBlock,
this->mNopeSize + this->mRopeSize}),
dtype);
this->h_kv_cache_tensor_ref = tensorrt_llm::runtime::BufferManager::pinned(
ITensor::makeShape({this->mBatchSize, 2, maxChunkBlockPerSeq, this->mNumHeads, this->mTokensPerBlock,
this->mNopeSize + this->mRopeSize}),
dtype);
this->h_compressed_kv_cache_tensor = tensorrt_llm::runtime::BufferManager::pinned(
ITensor::makeShape({this->mBatchSize, 2, this->mMaxBlockPerSeq, this->mNumHeads, this->mTokensPerBlock,
this->mLoraSize + this->mRopeSize}),
cacheType);
this->h_compressed_offset_tensor = tensorrt_llm::runtime::BufferManager::pinned(
ITensor::makeShape({this->mBatchSize, 2, this->mMaxBlockPerSeq + 1}), nvinfer1::DataType::kINT32);
this->d_kv_cache_tensor
= tensorrt_llm::runtime::BufferManager::gpuSync(this->h_kv_cache_tensor->getShape(), dtype);
this->d_compressed_kv_cache_tensor
= tensorrt_llm::runtime::BufferManager::gpuSync(this->h_compressed_kv_cache_tensor->getShape(), cacheType);
this->d_compressed_offset_tensor = tensorrt_llm::runtime::BufferManager::gpuSync(
this->h_compressed_offset_tensor->getShape(), nvinfer1::DataType::kINT32);
{
auto* compressed_kv_cache_ptr = bufferCast<TCache>(*(this->h_compressed_kv_cache_tensor));
auto* offset_ptr = bufferCast<int32_t>(*(this->h_compressed_offset_tensor));
this->memsetZeroHost(this->h_kv_cache_tensor);
this->memsetZeroHost(this->h_kv_cache_tensor_ref);
this->fillArrayDataWithMod(compressed_kv_cache_ptr, this->h_compressed_kv_cache_tensor->getSize());
this->fillKVOffsetData(
offset_ptr, this->h_compressed_offset_tensor->getSize(), false, this->mMaxBlockPerSeq);
cudaMemcpy(this->d_kv_cache_tensor->data(), this->h_kv_cache_tensor->data(),
this->h_kv_cache_tensor->getSizeInBytes(), cudaMemcpyHostToDevice);
cudaMemcpy(this->d_compressed_kv_cache_tensor->data(), this->h_compressed_kv_cache_tensor->data(),
this->h_compressed_kv_cache_tensor->getSizeInBytes(), cudaMemcpyHostToDevice);
cudaMemcpy(this->d_compressed_offset_tensor->data(), this->h_compressed_offset_tensor->data(),
this->h_compressed_offset_tensor->getSizeInBytes(), cudaMemcpyHostToDevice);
}
// tensor
// kv, k_pe for invokeMLALoadChunkedKV (kernel 1)
this->h_compressed_kv_output = tensorrt_llm::runtime::BufferManager::pinned(
ITensor::makeShape({this->mBatchSize * this->mChunkSize, 1, this->mLoraSize}), dtype);
this->h_k_pe_output = tensorrt_llm::runtime::BufferManager::pinned(
ITensor::makeShape({this->mBatchSize * this->mChunkSize, 1, this->mRopeSize}), dtype);
this->h_compressed_kv_output_ref = tensorrt_llm::runtime::BufferManager::pinned(
ITensor::makeShape({this->mBatchSize * this->mChunkSize, 1, this->mLoraSize}), dtype);
this->h_k_pe_output_ref = tensorrt_llm::runtime::BufferManager::pinned(
ITensor::makeShape({this->mBatchSize * this->mChunkSize, 1, this->mRopeSize}), dtype);
this->d_compressed_kv_output
= tensorrt_llm::runtime::BufferManager::gpuSync(this->h_compressed_kv_output->getShape(), dtype);
this->d_k_pe_output = tensorrt_llm::runtime::BufferManager::gpuSync(this->h_k_pe_output->getShape(), dtype);
{
this->memsetZeroHost(this->h_compressed_kv_output);
this->memsetZeroHost(this->h_k_pe_output);
this->memsetZeroHost(this->h_compressed_kv_output_ref);
this->memsetZeroHost(this->h_k_pe_output_ref);
cudaMemcpy(this->d_compressed_kv_output->data(), this->h_compressed_kv_output->data(),
this->h_compressed_kv_output->getSizeInBytes(), cudaMemcpyHostToDevice);
cudaMemcpy(this->d_k_pe_output->data(), this->h_k_pe_output->data(), this->h_k_pe_output->getSizeInBytes(),
cudaMemcpyHostToDevice);
}
// invokeMergeAttnWithSoftmax, we just ignore rope_size here for simplicity
this->m_h_q_tensor = tensorrt_llm::runtime::BufferManager::pinned(
ITensor::makeShape({this->mTotalQLen, this->mNumHeads, this->mNopeSize}), dtype);
this->m_h_kv_full_tensor = tensorrt_llm::runtime::BufferManager::pinned(
ITensor::makeShape({this->mTotalKVLen, 2, this->mNumHeads, this->mNopeSize}), dtype);
this->m_h_chunked_kv_tensor = tensorrt_llm::runtime::BufferManager::pinned(
ITensor::makeShape({this->mBatchSize * this->mChunkSize, 2, this->mNumHeads, this->mNopeSize}), dtype);
this->m_h_output_tensor = tensorrt_llm::runtime::BufferManager::pinned(
ITensor::makeShape({this->mTotalQLen, this->mNumHeads, this->mNopeSize}), dtype);
this->m_h_softmax_sum_tensor = tensorrt_llm::runtime::BufferManager::pinned(
ITensor::makeShape({2, this->mTotalQLen, this->mNumHeads}), nvinfer1::DataType::kFLOAT);
this->m_h_softmax_sum_accum_tensor = tensorrt_llm::runtime::BufferManager::pinned(
ITensor::makeShape({2, this->mTotalQLen, this->mNumHeads}), nvinfer1::DataType::kFLOAT);
this->m_h_output_tensor_ref = tensorrt_llm::runtime::BufferManager::pinned(
ITensor::makeShape({this->mTotalQLen, this->mNumHeads, this->mNopeSize}), dtype);
this->m_h_output_tensor_accum = tensorrt_llm::runtime::BufferManager::pinned(
ITensor::makeShape({this->mTotalQLen, this->mNumHeads, this->mNopeSize}), dtype);
this->m_h_merge_op = tensorrt_llm::runtime::BufferManager::pinned(
ITensor::makeShape({chunked_loop_num + 1, this->mBatchSize}), nvinfer1::DataType::kINT64);
this->m_d_q_tensor = tensorrt_llm::runtime::BufferManager::gpuSync(this->m_h_q_tensor->getShape(), dtype);
this->m_d_kv_full_tensor
= tensorrt_llm::runtime::BufferManager::gpuSync(this->m_h_kv_full_tensor->getShape(), dtype);
this->m_d_chunked_kv_tensor
= tensorrt_llm::runtime::BufferManager::gpuSync(this->m_h_chunked_kv_tensor->getShape(), dtype);
this->m_d_output_tensor
= tensorrt_llm::runtime::BufferManager::gpuSync(this->m_h_output_tensor->getShape(), dtype);
this->m_d_softmax_sum_tensor = tensorrt_llm::runtime::BufferManager::gpuSync(
this->m_h_softmax_sum_tensor->getShape(), nvinfer1::DataType::kFLOAT);
this->m_d_softmax_sum_accum_tensor = tensorrt_llm::runtime::BufferManager::gpuSync(
this->m_h_softmax_sum_accum_tensor->getShape(), nvinfer1::DataType::kFLOAT);
this->m_d_output_tensor_accum
= tensorrt_llm::runtime::BufferManager::gpuSync(this->m_h_output_tensor_accum->getShape(), dtype);
this->m_d_merge_op
= tensorrt_llm::runtime::BufferManager::gpuSync(this->m_h_merge_op->getShape(), nvinfer1::DataType::kINT64);
{
auto* q_ptr = bufferCast<DataType>(*(this->m_h_q_tensor));
auto* kv_ptr = bufferCast<DataType>(*(this->m_h_kv_full_tensor));
generateRandomData(q_ptr, m_h_q_tensor->getSize());
generateRandomData(kv_ptr, m_h_kv_full_tensor->getSize());
this->memsetZeroHost(m_h_chunked_kv_tensor);
this->memsetZeroHost(m_h_output_tensor);
this->memsetZeroHost(m_h_softmax_sum_tensor);
this->memsetZeroHost(m_h_softmax_sum_accum_tensor);
this->memsetZeroHost(m_h_output_tensor_ref);
this->memsetZeroHost(m_h_output_tensor_accum);
// Copy data to device
cudaMemcpyAsync(m_d_q_tensor->data(), m_h_q_tensor->data(), m_h_q_tensor->getSizeInBytes(),
cudaMemcpyHostToDevice, mStream->get());
cudaMemcpyAsync(m_d_kv_full_tensor->data(), m_h_kv_full_tensor->data(),
m_h_kv_full_tensor->getSizeInBytes(), cudaMemcpyHostToDevice, mStream->get());
cudaMemcpyAsync(m_d_chunked_kv_tensor->data(), m_h_chunked_kv_tensor->data(),
m_h_chunked_kv_tensor->getSizeInBytes(), cudaMemcpyHostToDevice, mStream->get());
cudaMemcpyAsync(m_d_output_tensor->data(), m_h_output_tensor->data(), m_h_output_tensor->getSizeInBytes(),
cudaMemcpyHostToDevice, mStream->get());
cudaMemcpyAsync(m_d_softmax_sum_tensor->data(), m_h_softmax_sum_tensor->data(),
m_h_softmax_sum_tensor->getSizeInBytes(), cudaMemcpyHostToDevice, mStream->get());
cudaMemcpyAsync(m_d_softmax_sum_accum_tensor->data(), m_h_softmax_sum_accum_tensor->data(),
m_h_softmax_sum_accum_tensor->getSizeInBytes(), cudaMemcpyHostToDevice, mStream->get());
cudaMemcpyAsync(m_d_output_tensor_accum->data(), m_h_output_tensor_accum->data(),
m_h_output_tensor_accum->getSizeInBytes(), cudaMemcpyHostToDevice, mStream->get());
cudaStreamSynchronize(mStream->get());
}
this->prepareChunkedPrefillMetaData();
return true;
}
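// Builds the host-side metadata that drives the chunked main loop: for each loop iteration it fills
// the per-sequence cumulative chunk lengths (cu_chunk_lens), the global token offset at which each
// sequence's chunk starts in the KV cache (chunked_ld_global_offset), and the per-sequence merge op
// (0 = skip, 1 = merge into the accumulator, 2 = copy/overwrite). The last row describes the
// uncached (query) tokens. Returns the number of chunked loop iterations.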
int prepareChunkedPrefillMetaData()
{
using tensorrt_llm::runtime::bufferCast;
int const total_chunk_size = this->mChunkSize * this->mBatchSize;
int chunked_loop_num = (this->mTotalKVLen - this->mTotalQLen + total_chunk_size - 1) / total_chunk_size;
auto* h_merge_op = bufferCast<int64_t>(*(this->m_h_merge_op)); // {chunked_loop_num + 1, batch_size}
auto* h_cu_q_seq_lens_ptr = bufferCast<int64_t>(*(this->h_cu_q_seq_lens)); // {batch_size + 1}
auto* h_cu_kv_seq_lens_ptr = bufferCast<int64_t>(*(this->h_cu_kv_seq_lens)); // {batch_size + 1}
auto* h_cu_chunk_lens_ptr
= bufferCast<int64_t>(*(this->h_cu_chunk_lens)); // {chunked_loop_num + 1, batch_size + 1}
auto* h_chunked_ld_global_offset_ptr
= bufferCast<int64_t>(*(this->h_chunked_ld_global_offset)); // {chunked_loop_num + 1, batch_size}
this->max_chunk_len_per_loop.clear();
std::vector<int64_t> chunked_seq_len_vec((chunked_loop_num + 1) * (this->mBatchSize), 0);
// 0 -> chunked_loop_num -1
int remain_buffer_len = total_chunk_size;
int curr_loop_idx = 0;
int temp_max_chunk_len = 0;
#define chunked_seq_len(chunked_loop_idx, b_idx) chunked_seq_len_vec[(chunked_loop_idx) * (this->mBatchSize) + (b_idx)]
#define cu_chunked_seq_len(chunked_loop_idx, b_idx) \
h_cu_chunk_lens_ptr[(chunked_loop_idx) * (this->mBatchSize + 1) + (b_idx)]
#define chunked_ld_global_offset(chunked_loop_idx, b_idx) \
h_chunked_ld_global_offset_ptr[(chunked_loop_idx) * (this->mBatchSize) + (b_idx)]
for (int b = 0; b < this->mBatchSize; b++)
{
int temp_cached_kv_len = (h_cu_kv_seq_lens_ptr[b + 1] - h_cu_kv_seq_lens_ptr[b])
- (h_cu_q_seq_lens_ptr[b + 1] - h_cu_q_seq_lens_ptr[b]);
while (temp_cached_kv_len > 0)
{
auto used_buffer_len = std::min(remain_buffer_len, temp_cached_kv_len);
remain_buffer_len -= used_buffer_len;
temp_cached_kv_len -= used_buffer_len;
temp_max_chunk_len = std::max(temp_max_chunk_len, used_buffer_len);
chunked_seq_len(curr_loop_idx, b) = used_buffer_len;
chunked_ld_global_offset(curr_loop_idx + 1, b)
= chunked_ld_global_offset(curr_loop_idx, b) + used_buffer_len;
if (remain_buffer_len == 0)
{
this->max_chunk_len_per_loop.push_back(temp_max_chunk_len);
temp_max_chunk_len = 0;
remain_buffer_len = total_chunk_size;
curr_loop_idx++;
}
}
}
if (this->max_chunk_len_per_loop.size() < static_cast<size_t>(chunked_loop_num))
{
this->max_chunk_len_per_loop.push_back(temp_max_chunk_len);
}
assert(this->max_chunk_len_per_loop.size() == static_cast<size_t>(chunked_loop_num));
// for not cached part
for (int b = 0; b < this->mBatchSize; b++)
{
int uncached_len = (h_cu_q_seq_lens_ptr[b + 1] - h_cu_q_seq_lens_ptr[b]);
chunked_seq_len(chunked_loop_num, b) = uncached_len;
}
for (int loop_idx = 0; loop_idx < chunked_loop_num + 1; loop_idx++)
{
for (int b = 0; b < this->mBatchSize; b++)
{
cu_chunked_seq_len(loop_idx, b + 1) = cu_chunked_seq_len(loop_idx, b) + chunked_seq_len(loop_idx, b);
}
}
// merge op
for (int loop_idx = 0; loop_idx < chunked_loop_num; loop_idx++)
{
for (int b = 0; b < this->mBatchSize; b++)
{
if (chunked_seq_len(loop_idx, b) != 0 && (loop_idx == 0 || chunked_seq_len(loop_idx - 1, b) == 0))
{
h_merge_op[loop_idx * (this->mBatchSize) + b] = 2; // copy
}
else if (chunked_seq_len(loop_idx, b) != 0)
{
h_merge_op[loop_idx * (this->mBatchSize) + b] = 1; // merge
}
else
{
h_merge_op[loop_idx * (this->mBatchSize) + b] = 0; // skip
}
}
}
// for the last uncached part
for (int b = 0; b < this->mBatchSize; b++)
{
int temp_cached_kv_len = (h_cu_kv_seq_lens_ptr[b + 1] - h_cu_kv_seq_lens_ptr[b])
- (h_cu_q_seq_lens_ptr[b + 1] - h_cu_q_seq_lens_ptr[b]);
if (temp_cached_kv_len == 0)
{
h_merge_op[chunked_loop_num * (this->mBatchSize) + b] = 2; // copy
}
else
{
h_merge_op[chunked_loop_num * (this->mBatchSize) + b] = 1; // merge
}
chunked_ld_global_offset(chunked_loop_num, b) = temp_cached_kv_len;
}
#undef chunked_seq_len
#undef cu_chunked_seq_len
#undef chunked_ld_global_offset
// copy to device
cudaMemcpy(this->d_cu_chunk_lens->data(), this->h_cu_chunk_lens->data(),
this->h_cu_chunk_lens->getSizeInBytes(), cudaMemcpyHostToDevice);
cudaMemcpy(this->d_chunked_ld_global_offset->data(), this->h_chunked_ld_global_offset->data(),
this->h_chunked_ld_global_offset->getSizeInBytes(), cudaMemcpyHostToDevice);
cudaMemcpy(this->m_d_merge_op->data(), this->m_h_merge_op->data(), this->m_h_merge_op->getSizeInBytes(),
cudaMemcpyHostToDevice);
#ifdef TRTLLM_MLA_CHUNKED_PREFILL_TEST_DBG
std::cout << "chunked_loop_num: " << chunked_loop_num << '\n';
this->showHostTensor<int64_t>(this->m_h_merge_op, "merge_op");
this->showHostTensor<int64_t>(this->h_chunked_ld_global_offset, "chunked_ld_global_offset");
this->showHostTensor<int64_t>(this->h_cu_chunk_lens, "cu_chunk_lens");
#endif
return chunked_loop_num;
}
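// Runs the full-sequence CPU attention reference in one pass; the result (m_h_output_tensor_ref) is
// the ground truth the chunked path is compared against.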
void PerformNormalAttention()
{
using tensorrt_llm::runtime::bufferCast;
auto* q_ptr = bufferCast<DataType>(*(this->m_h_q_tensor));
auto* kv_ptr = bufferCast<DataType>(*(this->m_h_kv_full_tensor));
auto* output_ptr = bufferCast<DataType>(*(this->m_h_output_tensor_ref));
auto* cu_q_seq_lens_ptr = bufferCast<int64_t>(*(this->h_cu_q_seq_lens));
auto* cu_kv_seq_lens_ptr = bufferCast<int64_t>(*(this->h_cu_kv_seq_lens));
selfAttentionRef(output_ptr, q_ptr, kv_ptr, this->mBatchSize, this->mNumHeads, cu_q_seq_lens_ptr,
cu_kv_seq_lens_ptr, this->mNopeSize, false, nullptr, this->mIsCausalMask);
}
void PerformMergedAttention()
{
using tensorrt_llm::runtime::bufferCast;
auto* h_q_ptr = bufferCast<DataType>(*(this->m_h_q_tensor));
auto* h_kv_ptr = bufferCast<DataType>(*(this->m_h_kv_full_tensor));
auto* h_chunked_kv_ptr = bufferCast<DataType>(*(this->m_h_chunked_kv_tensor));
auto* h_output_ptr = bufferCast<DataType>(*(this->m_h_output_tensor));
auto* h_output_accum_ptr = bufferCast<DataType>(*(this->m_h_output_tensor_accum));
auto* h_softmax_sum_ptr = bufferCast<float>(*(this->m_h_softmax_sum_tensor));
auto* h_softmax_sum_accum_ptr = bufferCast<float>(*(this->m_h_softmax_sum_accum_tensor));
auto* h_cu_q_seq_lens_ptr = bufferCast<int64_t>(*(this->h_cu_q_seq_lens));
auto* h_cu_kv_seq_lens_ptr = bufferCast<int64_t>(*(this->h_cu_kv_seq_lens));
auto* h_cu_chunk_lens_ptr = bufferCast<int64_t>(*(this->h_cu_chunk_lens));
auto* h_chunked_ld_global_offset_ptr = bufferCast<int64_t>(*(this->h_chunked_ld_global_offset));
auto* d_kv_ptr = bufferCast<DataType>(*(this->m_d_kv_full_tensor));
auto* d_chunked_kv_ptr = bufferCast<DataType>(*(this->m_d_chunked_kv_tensor));
auto* d_softmax_sum_ptr = bufferCast<float>(*(this->m_d_softmax_sum_tensor));
auto* d_softmax_sum_accum_ptr = bufferCast<float>(*(this->m_d_softmax_sum_accum_tensor));
auto* d_output_ptr = bufferCast<DataType>(*(this->m_d_output_tensor));
auto* d_output_accum_ptr = bufferCast<DataType>(*(this->m_d_output_tensor_accum));
auto* d_merge_op = bufferCast<int64_t>(*(this->m_d_merge_op));
auto* d_cu_q_seq_lens_ptr = bufferCast<int64_t>(*(this->d_cu_q_seq_lens));
int const total_chunk_size = this->mChunkSize * this->mBatchSize;
int chunked_loop_num = (this->mTotalKVLen - this->mTotalQLen + total_chunk_size - 1) / total_chunk_size;
// do not apply mask
for (int _ = 0; _ < chunked_loop_num; _++)
{
// copy related kv chunk data
copyRelatedChunkedKV(h_chunked_kv_ptr, h_kv_ptr, h_chunked_ld_global_offset_ptr, this->mBatchSize,
this->mNumHeads, h_cu_kv_seq_lens_ptr, h_cu_chunk_lens_ptr, this->mNopeSize);
// attention
selfAttentionRef<DataType>(h_output_ptr, h_q_ptr, h_chunked_kv_ptr, this->mBatchSize, this->mNumHeads,
h_cu_q_seq_lens_ptr, h_cu_chunk_lens_ptr, this->mNopeSize, true, h_softmax_sum_ptr, false);
// merge attention
// copy curr_attn and softmax_sum to device
cudaMemcpy(d_softmax_sum_ptr, h_softmax_sum_ptr, this->m_h_softmax_sum_tensor->getSizeInBytes(),
cudaMemcpyHostToDevice);
cudaMemcpy(d_output_ptr, h_output_ptr, this->m_h_output_tensor->getSizeInBytes(), cudaMemcpyHostToDevice);
// merge softmax
tensorrt_llm::kernels::invokeMergeAttnWithSoftmax<DataType>(d_output_accum_ptr, d_softmax_sum_accum_ptr,
d_output_accum_ptr, d_softmax_sum_accum_ptr, d_output_ptr, d_softmax_sum_ptr, this->mBatchSize,
d_cu_q_seq_lens_ptr, this->mMaxQSeqLen, d_merge_op, this->mNumHeads, this->mNopeSize, mStream->get());
cudaStreamSynchronize(mStream->get());
// copy merged softmax sum back to host
cudaMemcpy(h_softmax_sum_accum_ptr, d_softmax_sum_accum_ptr, this->m_h_softmax_sum_tensor->getSizeInBytes(),
cudaMemcpyDeviceToHost);
cudaMemcpy(h_output_accum_ptr, d_output_accum_ptr, this->m_h_output_tensor->getSizeInBytes(),
cudaMemcpyDeviceToHost);
// update merge op, ld global offset, cu chunk lens ptr.
d_merge_op += this->mBatchSize;
h_cu_chunk_lens_ptr += (this->mBatchSize + 1);
h_chunked_ld_global_offset_ptr += this->mBatchSize;
}
// final round, apply causal mask.
// copy the last chunked kv data
copyRelatedChunkedKV(h_chunked_kv_ptr, h_kv_ptr, h_chunked_ld_global_offset_ptr, this->mBatchSize,
this->mNumHeads, h_cu_kv_seq_lens_ptr, h_cu_chunk_lens_ptr, this->mNopeSize);
// attention
selfAttentionRef<DataType>(h_output_ptr, h_q_ptr, h_chunked_kv_ptr, this->mBatchSize, this->mNumHeads,
h_cu_q_seq_lens_ptr, h_cu_chunk_lens_ptr, this->mNopeSize, true, h_softmax_sum_ptr, this->mIsCausalMask);
// merge attention
// copy curr_attn and softmax_sum to device
cudaMemcpy(d_softmax_sum_ptr, h_softmax_sum_ptr, this->m_h_softmax_sum_tensor->getSizeInBytes(),
cudaMemcpyHostToDevice);
cudaMemcpy(d_output_ptr, h_output_ptr, this->m_h_output_tensor->getSizeInBytes(), cudaMemcpyHostToDevice);
tensorrt_llm::kernels::invokeMergeAttnWithSoftmax<DataType>(d_output_accum_ptr, d_softmax_sum_accum_ptr,
d_output_accum_ptr, d_softmax_sum_accum_ptr, d_output_ptr, d_softmax_sum_ptr, this->mBatchSize,
d_cu_q_seq_lens_ptr, this->mMaxQSeqLen, d_merge_op, this->mNumHeads, this->mNopeSize, mStream->get());
cudaStreamSynchronize(mStream->get());
// copy merged softmax sum back to host
cudaMemcpy(h_softmax_sum_accum_ptr, d_softmax_sum_accum_ptr, this->m_h_softmax_sum_tensor->getSizeInBytes(),
cudaMemcpyDeviceToHost);
cudaMemcpy(
h_output_accum_ptr, d_output_accum_ptr, this->m_h_output_tensor->getSizeInBytes(), cudaMemcpyDeviceToHost);
sync_check_cuda_error(mStream->get());
}
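// CPU reference for one chunked-load iteration: wraps the host copy of the compressed KV cache in a
// KVBlockArray and gathers the chunk described by cu_chunk_lens / chunked_ld_global_offset at
// chunk_idx into h_compressed_kv_output_ref / h_k_pe_output_ref.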
void PerformLoadChunkedKVRef(int chunk_idx)
{
using tensorrt_llm::runtime::bufferCast;
auto* compressed_kv_output_ptr = bufferCast<DataType>(*(this->h_compressed_kv_output_ref));
auto* k_pe_output_ptr = bufferCast<DataType>(*(this->h_k_pe_output_ref));
auto* compressed_kv_cache_ptr = bufferCast<TCache>(*(this->h_compressed_kv_cache_tensor));
auto* offset_ptr = bufferCast<int32_t>(*(this->h_compressed_offset_tensor));
auto* h_cu_chunk_lens_ptr = bufferCast<int64_t>(*(this->h_cu_chunk_lens)) + chunk_idx * (this->mBatchSize + 1);
auto* h_chunked_ld_global_offset_ptr
= bufferCast<int64_t>(*(this->h_chunked_ld_global_offset)) + chunk_idx * this->mBatchSize;
float* kv_scale_quant_orig_ptr = nullptr;
if constexpr (std::is_same_v<TCache, __nv_fp8_e4m3>)
{
kv_scale_quant_orig_ptr = bufferCast<float>(*(this->h_kv_scale_quant_orig));
}
tensorrt_llm::kernels::KVBlockArray kv_cache(this->mBatchSize, this->mMaxBlockPerSeq, this->mTokensPerBlock,
sizeof(TCache) * 1 * (this->mLoraSize + this->mRopeSize), 0, 0, 0, 0, compressed_kv_cache_ptr, nullptr,
reinterpret_cast<tensorrt_llm::kernels::KVBlockArrayForContextFMHA::DataType*>(offset_ptr));
loadChunkedKVKernelRef<DataType, TCache>(compressed_kv_output_ptr, k_pe_output_ptr, kv_cache, this->mBatchSize,
h_cu_chunk_lens_ptr, h_chunked_ld_global_offset_ptr, this->mLoraSize, this->mRopeSize,
kv_scale_quant_orig_ptr);
}
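// Runs the invokeMLALoadChunkedKV kernel for one chunked-load iteration on the device buffers and
// copies the kv / k_pe outputs back to the host for comparison.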
void PerformLoadChunkedKV(int chunk_idx)
{
using tensorrt_llm::runtime::bufferCast;
auto* compressed_kv_output_ptr = bufferCast<DataType>(*(this->d_compressed_kv_output));
auto* k_pe_output_ptr = bufferCast<DataType>(*(this->d_k_pe_output));
auto* compressed_kv_cache_ptr = bufferCast<TCache>(*(this->d_compressed_kv_cache_tensor));
auto* offset_ptr = bufferCast<int32_t>(*(this->d_compressed_offset_tensor));
auto* d_cu_chunk_lens_ptr = bufferCast<int64_t>(*(this->d_cu_chunk_lens)) + chunk_idx * (this->mBatchSize + 1);
auto* d_chunked_ld_global_offset_ptr
= bufferCast<int64_t>(*(this->d_chunked_ld_global_offset)) + chunk_idx * this->mBatchSize;
float* kv_scale_quant_orig_ptr = nullptr;
if constexpr (std::is_same_v<TCache, __nv_fp8_e4m3>)
{
kv_scale_quant_orig_ptr = bufferCast<float>(*(this->d_kv_scale_quant_orig));
}
tensorrt_llm::kernels::KVBlockArray kv_cache(this->mBatchSize, this->mMaxBlockPerSeq, this->mTokensPerBlock,
sizeof(TCache) * 1 * (this->mLoraSize + this->mRopeSize), 0, 0, 0, 0, compressed_kv_cache_ptr, nullptr,
reinterpret_cast<tensorrt_llm::kernels::KVBlockArrayForContextFMHA::DataType*>(offset_ptr));
// copy cu chunk lens to device
cudaMemcpy(this->d_cu_chunk_lens->data(), this->h_cu_chunk_lens->data(),
this->h_cu_chunk_lens->getSizeInBytes(), cudaMemcpyHostToDevice);
tensorrt_llm::kernels::invokeMLALoadChunkedKV<DataType, TCache>(compressed_kv_output_ptr, k_pe_output_ptr,
kv_cache, this->mBatchSize, d_cu_chunk_lens_ptr, d_chunked_ld_global_offset_ptr, this->mLoraSize,
this->mRopeSize, this->max_chunk_len_per_loop[chunk_idx], kv_scale_quant_orig_ptr, mStream->get());
cudaStreamSynchronize(this->mStream->get());
// copy result back to host
cudaMemcpy(this->h_compressed_kv_output->data(), compressed_kv_output_ptr,
this->h_compressed_kv_output->getSizeInBytes(), cudaMemcpyDeviceToHost);
cudaMemcpy(this->h_k_pe_output->data(), k_pe_output_ptr, this->h_k_pe_output->getSizeInBytes(),
cudaMemcpyDeviceToHost);
sync_check_cuda_error(this->mStream->get());
}
};
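// Test type pairs: first = activation dtype, second = KV-cache dtype. The __nv_fp8_e4m3 pairs
// exercise the FP8 dequantization path of the chunked-load kernel.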
using MLATypes
= ::testing::Types<std::pair<half, half>, std::pair<__nv_bfloat16, __nv_bfloat16>, std::pair<float, float>,
std::pair<half, __nv_fp8_e4m3>, std::pair<__nv_bfloat16, __nv_fp8_e4m3>, std::pair<float, __nv_fp8_e4m3>>;
TYPED_TEST_SUITE(MlaChunkedPrefillTest, MLATypes);
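// End-to-end check of the merge path: attention computed chunk-by-chunk and combined with
// invokeMergeAttnWithSoftmax must match single-pass attention over the full KV. Only instantiated
// when DataType == TCache; the FP8 cache pairs are covered by MlaChunkedLoad below.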
TYPED_TEST(MlaChunkedPrefillTest, MlaChunkedPrefillDefault)
{
using tensorrt_llm::runtime::bufferCast;
using DataType = typename TestFixture::DataType;
using TCache = typename TestFixture::TCache;
if constexpr (std::is_same_v<DataType, TCache>)
{
this->setDefaultParams();
this->allocateBuffers();
sync_check_cuda_error(this->mStream->get());
bool allEqual{true};
this->PerformNormalAttention();
sync_check_cuda_error(this->mStream->get());
this->PerformMergedAttention();
sync_check_cuda_error(this->mStream->get());
// check result
auto* output_ptr = bufferCast<DataType>(*(this->m_h_output_tensor_accum));
auto* output_ref_ptr = bufferCast<DataType>(*(this->m_h_output_tensor_ref));
for (int i = 0; i < this->m_h_output_tensor->getSize(); i++)
{
if (std::abs(static_cast<float>(output_ptr[i]) - static_cast<float>(output_ref_ptr[i]))
> getTolerance<DataType>(output_ptr[i]))
{
std::cout << "Output mismatch at index " << i << ": "
<< "expected " << static_cast<float>(output_ref_ptr[i]) << ", got "
<< static_cast<float>(output_ptr[i]) << std::endl;
allEqual = false;
break;
}
}
ASSERT_TRUE(allEqual);
}
}
TYPED_TEST(MlaChunkedPrefillTest, MlaChunkedPrefillCausalMask)
{
using tensorrt_llm::runtime::bufferCast;
using DataType = typename TestFixture::DataType;
using TCache = typename TestFixture::TCache;
if constexpr (std::is_same_v<DataType, TCache>)
{
this->setDefaultParams();
this->mIsCausalMask = true;
this->allocateBuffers();
sync_check_cuda_error(this->mStream->get());
bool allEqual{true};
this->PerformNormalAttention();
sync_check_cuda_error(this->mStream->get());
this->PerformMergedAttention();
sync_check_cuda_error(this->mStream->get());
// check result
auto* output_ptr = bufferCast<DataType>(*(this->m_h_output_tensor_accum));
auto* output_ref_ptr = bufferCast<DataType>(*(this->m_h_output_tensor_ref));
for (int i = 0; i < this->m_h_output_tensor->getSize(); i++)
{
if (std::abs(static_cast<float>(output_ptr[i]) - static_cast<float>(output_ref_ptr[i]))
> getTolerance<DataType>(output_ptr[i]))
{
std::cout << "Output mismatch at index " << i << ": "
<< "expected " << static_cast<float>(output_ref_ptr[i]) << ", got "
<< static_cast<float>(output_ptr[i]) << std::endl;
allEqual = false;
break;
}
}
ASSERT_TRUE(allEqual);
}
}
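// Compares invokeMLALoadChunkedKV against the CPU reference for each chunked-load iteration,
// including FP8 dequantization when TCache is __nv_fp8_e4m3.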
TYPED_TEST(MlaChunkedPrefillTest, MlaChunkedLoad)
{
using tensorrt_llm::runtime::bufferCast;
using DataType = typename TestFixture::DataType;
this->setDefaultParams();
this->allocateBuffers();
sync_check_cuda_error(this->mStream->get());
bool allEqual{true};
int const total_chunk_size = this->mChunkSize * this->mBatchSize;
int const chunked_loop_num = (this->mTotalKVLen - this->mTotalQLen + total_chunk_size - 1) / total_chunk_size;
for (int _ = 0; _ < chunked_loop_num - 1; _++)
{
this->PerformLoadChunkedKVRef(_);
sync_check_cuda_error(this->mStream->get());
this->PerformLoadChunkedKV(_);
sync_check_cuda_error(this->mStream->get());
// check result
auto* compressed_kv_output_ptr = bufferCast<DataType>(*(this->h_compressed_kv_output));
auto* compressed_kv_output_ref_ptr = bufferCast<DataType>(*(this->h_compressed_kv_output_ref));
auto* k_pe_output_ptr = bufferCast<DataType>(*(this->h_k_pe_output));
auto* k_pe_output_ref_ptr = bufferCast<DataType>(*(this->h_k_pe_output_ref));
// check kv
for (int i = 0; i < this->h_compressed_kv_output->getSize(); i++)
{
if (std::abs(static_cast<float>(compressed_kv_output_ptr[i])
- static_cast<float>(compressed_kv_output_ref_ptr[i]))
> getTolerance<DataType>(compressed_kv_output_ptr[i]))
{
std::cout << "Compressed KV output mismatch at loop: " << _ << " index " << i << ": "
<< "expected " << static_cast<float>(compressed_kv_output_ref_ptr[i]) << ", got "
<< static_cast<float>(compressed_kv_output_ptr[i]) << std::endl;
allEqual = false;
break;
}
}
// check k_pe
for (int i = 0; i < this->h_k_pe_output->getSize(); i++)
{
if (std::abs(static_cast<float>(k_pe_output_ptr[i]) - static_cast<float>(k_pe_output_ref_ptr[i]))
> getTolerance<DataType>(k_pe_output_ptr[i]))
{
std::cout << "kpe mismatch at loop: " << _ << " index " << i << ": "
<< "expected " << static_cast<float>(k_pe_output_ref_ptr[i]) << ", got "
<< static_cast<float>(k_pe_output_ptr[i]) << std::endl;
allEqual = false;
break;
}
}
}
ASSERT_TRUE(allEqual);
}