#include <gtest/gtest.h>

#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>

#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/quantization.h"
#include "tensorrt_llm/kernels/decodingCommon.h"
#include "tensorrt_llm/kernels/kvCacheUtils.h"
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/kernels/mlaChunkedPrefill.cuh"
#include "tensorrt_llm/runtime/cudaStream.h"

#include <algorithm>
#include <cmath>
#include <cstring>
#include <iostream>
#include <limits>
#include <memory>
#include <random>
#include <vector>

// #define TRTLLM_MLA_CHUNKED_PREFILL_TEST_DBG

namespace
{
// kv_output {total_tokens, h=1, lora_size}
// k_pe_output {total_tokens, h=1, rope_size}
template <typename T, typename TCache>
void loadChunkedKVKernelRef(T* kv_output, T* k_pe_output, tensorrt_llm::kernels::KVBlockArray const& kv_cache,
    int num_contexts, int64_t const* cu_ctx_chunked_len, int const lora_size, int const rope_size,
    int const chunk_size, int const chunk_idx, float const* kv_scale_quant_orig_ptr)
{
    int const head_size = lora_size + rope_size;
    float const kv_scale_quant_orig = kv_scale_quant_orig_ptr ? kv_scale_quant_orig_ptr[0] : 1.0f;
    for (int b = 0; b < num_contexts; b++)
    {
        int const chunked_len = cu_ctx_chunked_len[b + 1] - cu_ctx_chunked_len[b];
        for (int s = 0; s < chunked_len; s++)
        {
            int const local_token_idx = chunk_idx * chunk_size + s;
            int const ld_token_offset = (cu_ctx_chunked_len[b] + s);
            auto const* kv_src = reinterpret_cast<TCache const*>(kv_cache.getKBlockPtr(b, local_token_idx));
            for (int d = 0; d < head_size; d++)
            {
                auto kv_block_idx = kv_cache.getKVLocalIdx(local_token_idx, 0, head_size, d);
                auto src_data = kv_src[kv_block_idx];
                T data;
                if constexpr (std::is_same_v<TCache, __nv_fp8_e4m3>)
                {
                    data = T(float(src_data) * kv_scale_quant_orig);
                }
                else
                {
                    data = src_data;
                }
                if (d < lora_size)
                {
                    kv_output[ld_token_offset * lora_size + d] = data;
                }
                else
                {
                    k_pe_output[ld_token_offset * rope_size + (d - lora_size)] = data;
                }
            }
        }
    }
}

// kv {total_tokens, 2, h, nope_size}
// k_pe {total_tokens, h=1, rope_size}
// output {b, 2, ceil(max_seq / cache_tokens_per_block), h, cache_tokens_per_block, (nope_size + rope_size)}
// max_seq <= chunk_size
template <typename T>
void setChunkedKVCacheForMLAKernelRef(T* output, T* kv_ptr, T* k_pe_ptr, int num_contexts, int64_t const* cu_seq_len,
    int const max_input_seq_len, int num_heads, int nope_size, int rope_size, int cache_tokens_per_block)
{
    int head_size = nope_size + rope_size;
    int const kv_cache_size_per_block = num_heads * cache_tokens_per_block * head_size;
    int const kv_cache_block_num_per_seq = (max_input_seq_len + cache_tokens_per_block - 1) / cache_tokens_per_block;
    for (int b = 0; b < num_contexts; b++)
    {
        int const global_token_offset = cu_seq_len[b];
        int const current_seq_len = cu_seq_len[b + 1] - cu_seq_len[b];
        for (int s = 0; s < current_seq_len; s++)
        {
            int const global_token_idx = global_token_offset + s;
            int const kv_cache_block_offset_for_k
                = (b * 2 * kv_cache_block_num_per_seq + s / cache_tokens_per_block) * kv_cache_size_per_block;
            int const kv_cache_block_offset_for_v
                = kv_cache_block_offset_for_k + (kv_cache_block_num_per_seq * kv_cache_size_per_block);
            for (int h = 0; h < num_heads; h++)
            {
                int const ld_k_head_offset = (global_token_idx * 2 * num_heads * nope_size) + h * nope_size;
                int const ld_v_head_offset = ld_k_head_offset + num_heads * nope_size;
                int const ld_k_pe_head_offset = global_token_idx * rope_size;
                // copy kv
                for (int d = 0; d < nope_size; d++)
                {
                    int const ld_k_idx = ld_k_head_offset + d;
                    int const ld_v_idx = ld_v_head_offset + d;
                    int const st_k_idx = kv_cache_block_offset_for_k + h * cache_tokens_per_block * head_size
                        + (s % cache_tokens_per_block) * head_size + d;
                    int const st_v_idx = kv_cache_block_offset_for_v + h *
cache_tokens_per_block * head_size + (s % cache_tokens_per_block) * head_size + d; output[st_k_idx] = kv_ptr[ld_k_idx]; output[st_v_idx] = kv_ptr[ld_v_idx]; } // copy k_pe for (int d = 0; d < rope_size; d++) { int const ld_k_pe_idx = ld_k_pe_head_offset + d; int const st_k_pe_idx = kv_cache_block_offset_for_k + h * cache_tokens_per_block * head_size + (s % cache_tokens_per_block) * head_size + (nope_size + d); output[st_k_pe_idx] = k_pe_ptr[ld_k_pe_idx]; } } } } } // Q {total_q, H, D} // KV {total_kv, 2, H, D} // softmax_sum {total_q, H, 2} // {max/sum} // output {total_q, H, D} // total_q <= total_kv template void selfAttentionRef(T* output, T* const Q, T* const KV, int batch_size, int num_heads, int64_t* const cu_seq_q_len, int64_t* const cu_seq_kv_len, int head_size, bool return_softmax, float* softmax_sum, bool causal_mask) { for (int b = 0; b < batch_size; b++) { int curr_q_len = cu_seq_q_len[b + 1] - cu_seq_q_len[b]; int curr_kv_len = cu_seq_kv_len[b + 1] - cu_seq_kv_len[b]; int global_q_offset = cu_seq_q_len[b] * num_heads * head_size; int global_kv_offset = cu_seq_kv_len[b] * 2 * num_heads * head_size; int global_softmax_offset = cu_seq_q_len[b] * num_heads * 2; float bmm1_scale = 1.F / std::sqrt(static_cast(head_size)); if (curr_q_len == 0 || curr_kv_len == 0) { continue; // skip empty sequences } std::vector P(curr_q_len * curr_kv_len); for (int h = 0; h < num_heads; h++) { // BMM1 std::fill(P.begin(), P.end(), std::numeric_limits::lowest()); T* const q_ptr = Q + global_q_offset + h * head_size; T* const k_ptr = KV + global_kv_offset + h * head_size; T* const v_ptr = k_ptr + num_heads * head_size; T* output_ptr = output + global_q_offset + h * head_size; for (int s_q = 0; s_q < curr_q_len; s_q++) { float softmax_max = std::numeric_limits::lowest(); for (int s_kv = 0; s_kv < curr_kv_len; s_kv++) { // lower right mask if (causal_mask && s_kv > curr_kv_len - curr_q_len + s_q) { break; } P[s_q * curr_kv_len + s_kv] = 0; for (int d = 0; d < head_size; d++) { P[s_q * curr_kv_len + s_kv] += static_cast( q_ptr[s_q * num_heads * head_size + d] * k_ptr[s_kv * 2 * num_heads * head_size + d]); } if (softmax_max < P[s_q * curr_kv_len + s_kv]) { softmax_max = P[s_q * curr_kv_len + s_kv]; } } for (int s_kv = 0; s_kv < curr_kv_len; s_kv++) { // lower right mask if (causal_mask && s_kv > curr_kv_len - curr_q_len + s_q) { break; } P[s_q * curr_kv_len + s_kv] -= softmax_max; } if (return_softmax) { softmax_sum[global_softmax_offset + s_q * num_heads * 2 + h * 2] = softmax_max; } } // softmax for (int s_q = 0; s_q < curr_q_len; s_q++) { float sum = 0; for (int s_kv = 0; s_kv < curr_kv_len; s_kv++) { // P[s_q * curr_kv_len + s_kv] = std::exp(P[s_q * curr_kv_len + s_kv] * bmm1_scale); // hack for real mla kernel P[s_q * curr_kv_len + s_kv] = std::exp(P[s_q * curr_kv_len + s_kv] * 0.072168784); sum += P[s_q * curr_kv_len + s_kv]; } for (int s_kv = 0; s_kv < curr_kv_len; s_kv++) { P[s_q * curr_kv_len + s_kv] /= sum; } if (return_softmax) { softmax_sum[global_softmax_offset + s_q * num_heads * 2 + h * 2 + 1] = sum; } } // BMM2 for (int s_q = 0; s_q < curr_q_len; s_q++) { for (int d = 0; d < head_size; d++) { output_ptr[s_q * num_heads * head_size + d] = 0; for (int s_kv = 0; s_kv < curr_kv_len; s_kv++) { output_ptr[s_q * num_heads * head_size + d] += static_cast(P[s_q * curr_kv_len + s_kv] * static_cast(v_ptr[s_kv * 2 * num_heads * head_size + d])); } } } } } } // chunked_KV {total_chunk_token, 2, H, D} // KV {total_kv_token, 2, H, D} template void copyRelatedChunkedKV(T* chunked_kv, T* const kv, int 
chunk_idx, int chunk_size, int batch_size, int num_heads, int64_t* const cu_kv_seq_len, int64_t* const cu_chunked_seq_len, int head_size) { for (int b = 0; b < batch_size; b++) { int src_global_offset = (cu_kv_seq_len[b] + chunk_idx * chunk_size) * 2 * num_heads * head_size; int dst_global_offset = cu_chunked_seq_len[b] * 2 * num_heads * head_size; int copy_length = cu_chunked_seq_len[b + 1] - cu_chunked_seq_len[b]; if (copy_length <= 0) { continue; // skip empty sequences } std::memcpy(chunked_kv + dst_global_offset, kv + src_global_offset, copy_length * 2 * num_heads * head_size * sizeof(T)); } } // chunked_KV {total_chunk_token, 2, H, D} // KV {total_kv_token, 2, H, D} // It will copy the last chunk of KV cache to chunked_KV cache and calculate the cu_chunked_seq_len template void copyFinalChunkedKV(T* chunked_kv, T* const kv, int chunk_size, int batch_size, int num_heads, int64_t* const cu_kv_seq_len, int64_t* cu_chunked_seq_len, int head_size, int64_t* merge_op) { cu_chunked_seq_len[0] = 0; for (int b = 0; b < batch_size; b++) { int curr_kv_len = cu_kv_seq_len[b + 1] - cu_kv_seq_len[b]; int last_chunk_size = curr_kv_len % chunk_size; if (last_chunk_size == 0) { last_chunk_size = chunk_size; // ensure at least one chunk } if (last_chunk_size == curr_kv_len) { merge_op[b] = 2; // no need to merge, just copy } else { merge_op[b] = 1; } cu_chunked_seq_len[b + 1] = cu_chunked_seq_len[b] + last_chunk_size; int global_token_offset = cu_kv_seq_len[b] + curr_kv_len - last_chunk_size; int copy_length = last_chunk_size; if (copy_length <= 0) { printf("copy_length is zero for batch %d, skipping...\n", b); continue; // skip empty sequences } int src_global_offset = global_token_offset * 2 * num_heads * head_size; int dst_global_offset = cu_chunked_seq_len[b] * 2 * num_heads * head_size; std::memcpy(chunked_kv + dst_global_offset, kv + src_global_offset, copy_length * 2 * num_heads * head_size * sizeof(T)); } } template float getTolerance(float scale = 1.f) { float tol = 0.0; if constexpr (std::is_same_v) { tol = 0.1; } else if constexpr (std::is_same_v) { tol = 0.001; } else if constexpr (std::is_same_v) { tol = 0.005; } else if constexpr (std::is_same_v) { tol = 0.05; } // Keep the scale in a sane range return std::max(tol, scale * tol); } }; // namespace template class MlaChunkedPrefillTest : public ::testing::Test { protected: using DataType = typename Typepair::first_type; using TCache = typename Typepair::second_type; static_assert(std::is_same_v || std::is_same_v, "TCache must be either the same type as DataType or __nv_fp8_e4m3"); std::shared_ptr mStream; tensorrt_llm::runtime::BufferManager::ITensorPtr h_kv_cache_tensor{nullptr}, h_kv_cache_tensor_ref{nullptr}, d_kv_cache_tensor{nullptr}, h_compressed_kv_cache_tensor{nullptr}, d_compressed_kv_cache_tensor{nullptr}, h_compressed_offset_tensor{nullptr}, d_compressed_offset_tensor{nullptr}, h_cu_kv_seq_lens{nullptr}, d_cu_kv_seq_lens{nullptr}, h_cu_chunk_lens{nullptr}, d_cu_chunk_lens{nullptr}, h_cu_q_seq_lens{nullptr}, d_cu_q_seq_lens{nullptr}, // for kernel 1 h_compressed_kv_output{nullptr}, d_compressed_kv_output{nullptr}, h_k_pe_output{nullptr}, d_k_pe_output{nullptr}, h_compressed_kv_output_ref{nullptr}, h_k_pe_output_ref{nullptr}, h_kv_scale_quant_orig{nullptr}, d_kv_scale_quant_orig{nullptr}, // for kernel 2 h_kv_tensor{nullptr}, d_kv_tensor{nullptr}, h_k_pe_tensor{nullptr}, d_k_pe_tensor{nullptr}, // for merge attn {kv_full_tensor = kv + k_pe} m_h_q_tensor{nullptr}, m_h_kv_full_tensor{nullptr}, m_h_chunked_kv_tensor{nullptr}, 
m_h_output_tensor{nullptr}, m_h_softmax_sum_tensor{nullptr}, m_h_softmax_sum_accum_tensor{nullptr}, m_h_output_tensor_ref{nullptr}, m_h_output_tensor_accum{nullptr}, m_d_q_tensor{nullptr}, m_d_kv_full_tensor{nullptr}, m_d_chunked_kv_tensor{nullptr}, m_d_output_tensor{nullptr}, m_d_softmax_sum_tensor{nullptr}, m_d_softmax_sum_accum_tensor{nullptr}, m_d_output_tensor_accum{nullptr}, m_h_merge_op{nullptr}, m_d_merge_op{nullptr}; int mBatchSize{}; int mMaxSeqLen{}; int mMaxQSeqLen{}; int mTotalQLen{}; int mTotalKVLen{}; int mChunkSize{}; int mNumHeads{}; int mLoraSize{}; int mRopeSize{}; int mNopeSize{}; int mMaxGenLength{}; // int mHeadSize{}; int mTokensPerBlock{}; int mMaxBlockPerSeq{}; bool mIsCausalMask{}; std::mt19937 gen; void SetUp() override { if (shouldSkip()) { GTEST_SKIP() << "Skipping mla chunked prefill test"; } mStream = std::make_shared(); gen.seed(42U); } static bool shouldSkip() { return false; } void setDefaultParams() { mBatchSize = 16; // mMaxSeqLen = 128; mChunkSize = 16; mNumHeads = 16; mLoraSize = 512; mRopeSize = 64; mNopeSize = 128; mIsCausalMask = false; mMaxGenLength = 128; mTokensPerBlock = 16; assert(this->mChunkSize % this->mTokensPerBlock == 0); } void memsetZeroHost(tensorrt_llm::runtime::BufferManager::ITensorPtr& tensor) { void* ptr = tensor->data(); std::memset(ptr, 0, tensor->getSizeInBytes()); } template void showHostTensor(tensorrt_llm::runtime::BufferManager::ITensorPtr& tensor) { auto* const ptr = reinterpret_cast(tensor->data()); for (int _ = 0; _ < tensor->getSize(); _++) { std::cout << static_cast(ptr[_]) << " "; } std::cout << std::endl; } int generateRandomSizeSmallerThan(int a) { if (a <= 0) { return 0; } std::uniform_int_distribution<> distrib(0, a - 1); // Generate and return the random number return int{distrib(gen)}; } float generateRandomFloat(float min, float max) { std::uniform_real_distribution dist(min, max); return dist(gen); } template void generateRandomData(T* data, int size) { for (int i = 0; i < size; i++) { data[i] = static_cast(generateRandomFloat(-1.0f, 1.0f)); } } template void fillKVOffsetData(T* arr, size_t size, bool use_both_kv = true, int max_block_per_seq = 0) { if (use_both_kv) { for (int i = 0; i < size; i++) { arr[i] = static_cast(i); } } else { int temp_idx = 0; for (int i = 0; i < size; i++) { bool is_v = (((i / max_block_per_seq) % 2) == 1); if (is_v) { arr[i] = static_cast(0); } else { arr[i] = static_cast(temp_idx); temp_idx++; } } } } template void fillArrayDataWithMod(T* arr, size_t size) { for (int i = 0; i < size; i++) { arr[i] = static_cast(i % 448); } } bool allocateBuffers() { using tensorrt_llm::runtime::BufferManager; using tensorrt_llm::runtime::CudaStream; using tensorrt_llm::runtime::ITensor; using tensorrt_llm::runtime::bufferCast; auto dtype = nvinfer1::DataType::kHALF; if constexpr (std::is_same_v) { dtype = nvinfer1::DataType::kFLOAT; } else if constexpr (std::is_same_v) { dtype = nvinfer1::DataType::kHALF; } else if constexpr (std::is_same_v) { dtype = nvinfer1::DataType::kBF16; } else { return false; } auto cacheType = dtype; if constexpr (std::is_same_v) { cacheType = nvinfer1::DataType::kFP8; this->h_kv_scale_quant_orig = tensorrt_llm::runtime::BufferManager::pinned(ITensor::makeShape({1}), nvinfer1::DataType::kFLOAT); this->d_kv_scale_quant_orig = tensorrt_llm::runtime::BufferManager::gpuSync(ITensor::makeShape({1}), nvinfer1::DataType::kFLOAT); auto* kv_scale_quant_orig_ptr = bufferCast(*(this->h_kv_scale_quant_orig)); float kv_scale_orig_quant = 2.0F; kv_scale_quant_orig_ptr[0] = 1.0 / 
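// the kernels consume the dequant scale (quant -> orig), i.e. the reciprocal of kv_scale_orig_quant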
kv_scale_orig_quant; cudaMemcpy(this->d_kv_scale_quant_orig->data(), this->h_kv_scale_quant_orig->data(), this->h_kv_scale_quant_orig->getSizeInBytes(), cudaMemcpyHostToDevice); } // cu lens this->h_cu_kv_seq_lens = tensorrt_llm::runtime::BufferManager::pinned( ITensor::makeShape({this->mBatchSize + 1}), nvinfer1::DataType::kINT64); this->h_cu_chunk_lens = tensorrt_llm::runtime::BufferManager::pinned( ITensor::makeShape({this->mBatchSize + 1}), nvinfer1::DataType::kINT64); this->h_cu_q_seq_lens = tensorrt_llm::runtime::BufferManager::pinned( ITensor::makeShape({this->mBatchSize + 1}), nvinfer1::DataType::kINT64); this->d_cu_kv_seq_lens = tensorrt_llm::runtime::BufferManager::gpuSync( this->h_cu_kv_seq_lens->getShape(), nvinfer1::DataType::kINT64); this->d_cu_chunk_lens = tensorrt_llm::runtime::BufferManager::gpuSync( this->h_cu_chunk_lens->getShape(), nvinfer1::DataType::kINT64); this->d_cu_q_seq_lens = tensorrt_llm::runtime::BufferManager::gpuSync( this->h_cu_q_seq_lens->getShape(), nvinfer1::DataType::kINT64); { this->mMaxSeqLen = 0; this->mMaxQSeqLen = 0; this->mTotalQLen = 0; this->mTotalKVLen = 0; // we only initialize cu_seq_lens auto* cu_kv_seq_lens_ptr = bufferCast(*(this->h_cu_kv_seq_lens)); auto* cu_q_seq_lens_ptr = bufferCast(*(this->h_cu_q_seq_lens)); cu_kv_seq_lens_ptr[0] = 0; cu_q_seq_lens_ptr[0] = 0; for (int i = 0; i < this->mBatchSize; i++) { int temp_seq_len = this->generateRandomSizeSmallerThan(this->mMaxGenLength); if (temp_seq_len == 0) { temp_seq_len = 1; // ensure at least one token } this->mMaxSeqLen = std::max(this->mMaxSeqLen, temp_seq_len); cu_kv_seq_lens_ptr[i + 1] = cu_kv_seq_lens_ptr[i] + temp_seq_len; auto temp_q_seq_len = temp_seq_len % this->mChunkSize; if (temp_q_seq_len == 0) { temp_q_seq_len = this->mChunkSize; // ensure at least one chunk } cu_q_seq_lens_ptr[i + 1] = cu_q_seq_lens_ptr[i] + temp_q_seq_len; this->mMaxQSeqLen = std::max(this->mMaxQSeqLen, temp_q_seq_len); this->mTotalQLen += temp_q_seq_len; this->mTotalKVLen += temp_seq_len; } cudaMemcpy(this->d_cu_kv_seq_lens->data(), this->h_cu_kv_seq_lens->data(), this->h_cu_kv_seq_lens->getSizeInBytes(), cudaMemcpyHostToDevice); cudaMemcpy(this->d_cu_q_seq_lens->data(), this->h_cu_q_seq_lens->data(), this->h_cu_q_seq_lens->getSizeInBytes(), cudaMemcpyHostToDevice); #ifdef TRTLLM_MLA_CHUNKED_PREFILL_TEST_DBG this->showHostTensor(this->h_cu_q_seq_lens); this->showHostTensor(this->h_cu_kv_seq_lens); #endif } // kv cache this->mMaxBlockPerSeq = (this->mMaxSeqLen + this->mTokensPerBlock - 1) / this->mTokensPerBlock; int maxChunkBlockPerSeq = (this->mChunkSize + this->mTokensPerBlock - 1) / this->mTokensPerBlock; this->h_kv_cache_tensor = tensorrt_llm::runtime::BufferManager::pinned( ITensor::makeShape({this->mBatchSize, 2, maxChunkBlockPerSeq, this->mNumHeads, this->mTokensPerBlock, this->mNopeSize + this->mRopeSize}), dtype); this->h_kv_cache_tensor_ref = tensorrt_llm::runtime::BufferManager::pinned( ITensor::makeShape({this->mBatchSize, 2, maxChunkBlockPerSeq, this->mNumHeads, this->mTokensPerBlock, this->mNopeSize + this->mRopeSize}), dtype); this->h_compressed_kv_cache_tensor = tensorrt_llm::runtime::BufferManager::pinned( ITensor::makeShape({this->mBatchSize, 2, this->mMaxBlockPerSeq, this->mNumHeads, this->mTokensPerBlock, this->mLoraSize + this->mRopeSize}), cacheType); this->h_compressed_offset_tensor = tensorrt_llm::runtime::BufferManager::pinned( ITensor::makeShape({this->mBatchSize, 2, this->mMaxBlockPerSeq + 1}), nvinfer1::DataType::kINT32); this->d_kv_cache_tensor = 
tensorrt_llm::runtime::BufferManager::gpuSync(this->h_kv_cache_tensor->getShape(), dtype); this->d_compressed_kv_cache_tensor = tensorrt_llm::runtime::BufferManager::gpuSync(this->h_compressed_kv_cache_tensor->getShape(), cacheType); this->d_compressed_offset_tensor = tensorrt_llm::runtime::BufferManager::gpuSync( this->h_compressed_offset_tensor->getShape(), nvinfer1::DataType::kINT32); { auto* compressed_kv_cache_ptr = bufferCast(*(this->h_compressed_kv_cache_tensor)); auto* offset_ptr = bufferCast(*(this->h_compressed_offset_tensor)); this->memsetZeroHost(this->h_kv_cache_tensor); this->memsetZeroHost(this->h_kv_cache_tensor_ref); this->fillArrayDataWithMod(compressed_kv_cache_ptr, this->h_compressed_kv_cache_tensor->getSize()); this->fillKVOffsetData( offset_ptr, this->h_compressed_offset_tensor->getSize(), false, this->mMaxBlockPerSeq); cudaMemcpy(this->d_kv_cache_tensor->data(), this->h_kv_cache_tensor->data(), this->h_kv_cache_tensor->getSizeInBytes(), cudaMemcpyHostToDevice); cudaMemcpy(this->d_compressed_kv_cache_tensor->data(), this->h_compressed_kv_cache_tensor->data(), this->h_compressed_kv_cache_tensor->getSizeInBytes(), cudaMemcpyHostToDevice); cudaMemcpy(this->d_compressed_offset_tensor->data(), this->h_compressed_offset_tensor->data(), this->h_compressed_offset_tensor->getSizeInBytes(), cudaMemcpyHostToDevice); } // tensor // kv, k_pe for invokeMLALoadChunkedKV (kernel 1) this->h_compressed_kv_output = tensorrt_llm::runtime::BufferManager::pinned( ITensor::makeShape({this->mBatchSize * this->mChunkSize, 1, this->mLoraSize}), dtype); this->h_k_pe_output = tensorrt_llm::runtime::BufferManager::pinned( ITensor::makeShape({this->mBatchSize * this->mChunkSize, 1, this->mRopeSize}), dtype); this->h_compressed_kv_output_ref = tensorrt_llm::runtime::BufferManager::pinned( ITensor::makeShape({this->mBatchSize * this->mChunkSize, 1, this->mLoraSize}), dtype); this->h_k_pe_output_ref = tensorrt_llm::runtime::BufferManager::pinned( ITensor::makeShape({this->mBatchSize * this->mChunkSize, 1, this->mRopeSize}), dtype); this->d_compressed_kv_output = tensorrt_llm::runtime::BufferManager::gpuSync(this->h_compressed_kv_output->getShape(), dtype); this->d_k_pe_output = tensorrt_llm::runtime::BufferManager::gpuSync(this->h_k_pe_output->getShape(), dtype); { this->memsetZeroHost(this->h_compressed_kv_output); this->memsetZeroHost(this->h_k_pe_output); this->memsetZeroHost(this->h_compressed_kv_output_ref); this->memsetZeroHost(this->h_k_pe_output_ref); cudaMemcpy(this->d_compressed_kv_output->data(), this->h_compressed_kv_output->data(), this->h_compressed_kv_output->getSizeInBytes(), cudaMemcpyHostToDevice); cudaMemcpy(this->d_k_pe_output->data(), this->h_k_pe_output->data(), this->h_k_pe_output->getSizeInBytes(), cudaMemcpyHostToDevice); } // kv, k_pe for invokeMLASetChunkedKV (kernel 2) this->h_kv_tensor = tensorrt_llm::runtime::BufferManager::pinned( ITensor::makeShape({this->mBatchSize * this->mChunkSize, 2, this->mNumHeads, this->mNopeSize}), dtype); this->h_k_pe_tensor = tensorrt_llm::runtime::BufferManager::pinned( ITensor::makeShape({this->mBatchSize * this->mChunkSize, 1, this->mRopeSize}), dtype); this->d_kv_tensor = tensorrt_llm::runtime::BufferManager::gpuSync(this->h_kv_tensor->getShape(), dtype); this->d_k_pe_tensor = tensorrt_llm::runtime::BufferManager::gpuSync(this->h_k_pe_tensor->getShape(), dtype); { auto* kv_ptr = bufferCast(*(this->h_kv_tensor)); auto* k_pe_ptr = bufferCast(*(this->h_k_pe_tensor)); fillArrayDataWithMod(kv_ptr, h_kv_tensor->getSize()); 
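// k_pe gets the same deterministic (i % 448) pattern, keeping the kernel-2 comparison reproducible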
fillArrayDataWithMod(k_pe_ptr, h_k_pe_tensor->getSize()); cudaMemcpyAsync(d_kv_tensor->data(), h_kv_tensor->data(), h_kv_tensor->getSizeInBytes(), cudaMemcpyHostToDevice, mStream->get()); cudaMemcpyAsync(d_k_pe_tensor->data(), h_k_pe_tensor->data(), h_k_pe_tensor->getSizeInBytes(), cudaMemcpyHostToDevice, mStream->get()); cudaStreamSynchronize(mStream->get()); } // invokeMergeAttnWithSoftmax, we just ignore rope_size here for simplicity this->m_h_q_tensor = tensorrt_llm::runtime::BufferManager::pinned( ITensor::makeShape({this->mTotalQLen, this->mNumHeads, this->mNopeSize}), dtype); this->m_h_kv_full_tensor = tensorrt_llm::runtime::BufferManager::pinned( ITensor::makeShape({this->mTotalKVLen, 2, this->mNumHeads, this->mNopeSize}), dtype); this->m_h_chunked_kv_tensor = tensorrt_llm::runtime::BufferManager::pinned( ITensor::makeShape({this->mBatchSize * this->mChunkSize, 2, this->mNumHeads, this->mNopeSize}), dtype); this->m_h_output_tensor = tensorrt_llm::runtime::BufferManager::pinned( ITensor::makeShape({this->mTotalQLen, this->mNumHeads, this->mNopeSize}), dtype); this->m_h_softmax_sum_tensor = tensorrt_llm::runtime::BufferManager::pinned( ITensor::makeShape({2, this->mTotalQLen, this->mNumHeads}), nvinfer1::DataType::kFLOAT); this->m_h_softmax_sum_accum_tensor = tensorrt_llm::runtime::BufferManager::pinned( ITensor::makeShape({2, this->mTotalQLen, this->mNumHeads}), nvinfer1::DataType::kFLOAT); this->m_h_output_tensor_ref = tensorrt_llm::runtime::BufferManager::pinned( ITensor::makeShape({this->mTotalQLen, this->mNumHeads, this->mNopeSize}), dtype); this->m_h_output_tensor_accum = tensorrt_llm::runtime::BufferManager::pinned( ITensor::makeShape({this->mTotalQLen, this->mNumHeads, this->mNopeSize}), dtype); this->m_h_merge_op = tensorrt_llm::runtime::BufferManager::pinned( ITensor::makeShape({this->mBatchSize}), nvinfer1::DataType::kINT64); this->m_d_q_tensor = tensorrt_llm::runtime::BufferManager::gpuSync(this->m_h_q_tensor->getShape(), dtype); this->m_d_kv_full_tensor = tensorrt_llm::runtime::BufferManager::gpuSync(this->m_h_kv_full_tensor->getShape(), dtype); this->m_d_chunked_kv_tensor = tensorrt_llm::runtime::BufferManager::gpuSync(this->m_h_chunked_kv_tensor->getShape(), dtype); this->m_d_output_tensor = tensorrt_llm::runtime::BufferManager::gpuSync(this->m_h_output_tensor->getShape(), dtype); this->m_d_softmax_sum_tensor = tensorrt_llm::runtime::BufferManager::gpuSync( this->m_h_softmax_sum_tensor->getShape(), nvinfer1::DataType::kFLOAT); this->m_d_softmax_sum_accum_tensor = tensorrt_llm::runtime::BufferManager::gpuSync( this->m_h_softmax_sum_accum_tensor->getShape(), nvinfer1::DataType::kFLOAT); this->m_d_output_tensor_accum = tensorrt_llm::runtime::BufferManager::gpuSync(this->m_h_output_tensor_accum->getShape(), dtype); this->m_d_merge_op = tensorrt_llm::runtime::BufferManager::gpuSync(this->m_h_merge_op->getShape(), nvinfer1::DataType::kINT64); { auto* q_ptr = bufferCast(*(this->m_h_q_tensor)); auto* kv_ptr = bufferCast(*(this->m_h_kv_full_tensor)); generateRandomData(q_ptr, m_h_q_tensor->getSize()); generateRandomData(kv_ptr, m_h_kv_full_tensor->getSize()); this->memsetZeroHost(m_h_chunked_kv_tensor); this->memsetZeroHost(m_h_output_tensor); this->memsetZeroHost(m_h_softmax_sum_tensor); this->memsetZeroHost(m_h_softmax_sum_accum_tensor); this->memsetZeroHost(m_h_output_tensor_ref); this->memsetZeroHost(m_h_output_tensor_accum); // Copy data to device cudaMemcpyAsync(m_d_q_tensor->data(), m_h_q_tensor->data(), m_h_q_tensor->getSizeInBytes(), cudaMemcpyHostToDevice, 
mStream->get()); cudaMemcpyAsync(m_d_kv_full_tensor->data(), m_h_kv_full_tensor->data(), m_h_kv_full_tensor->getSizeInBytes(), cudaMemcpyHostToDevice, mStream->get()); cudaMemcpyAsync(m_d_chunked_kv_tensor->data(), m_h_chunked_kv_tensor->data(), m_h_chunked_kv_tensor->getSizeInBytes(), cudaMemcpyHostToDevice, mStream->get()); cudaMemcpyAsync(m_d_output_tensor->data(), m_h_output_tensor->data(), m_h_output_tensor->getSizeInBytes(), cudaMemcpyHostToDevice, mStream->get()); cudaMemcpyAsync(m_d_softmax_sum_tensor->data(), m_h_softmax_sum_tensor->data(), m_h_softmax_sum_tensor->getSizeInBytes(), cudaMemcpyHostToDevice, mStream->get()); cudaMemcpyAsync(m_d_softmax_sum_accum_tensor->data(), m_h_softmax_sum_accum_tensor->data(), m_h_softmax_sum_accum_tensor->getSizeInBytes(), cudaMemcpyHostToDevice, mStream->get()); cudaMemcpyAsync(m_d_output_tensor_accum->data(), m_h_output_tensor_accum->data(), m_h_output_tensor_accum->getSizeInBytes(), cudaMemcpyHostToDevice, mStream->get()); cudaStreamSynchronize(mStream->get()); } return true; } void PerformNormalAttention() { using tensorrt_llm::runtime::bufferCast; auto* q_ptr = bufferCast(*(this->m_h_q_tensor)); auto* kv_ptr = bufferCast(*(this->m_h_kv_full_tensor)); auto* output_ptr = bufferCast(*(this->m_h_output_tensor_ref)); auto* cu_q_seq_lens_ptr = bufferCast(*(this->h_cu_q_seq_lens)); auto* cu_kv_seq_lens_ptr = bufferCast(*(this->h_cu_kv_seq_lens)); selfAttentionRef(output_ptr, q_ptr, kv_ptr, this->mBatchSize, this->mNumHeads, cu_q_seq_lens_ptr, cu_kv_seq_lens_ptr, this->mNopeSize, false, nullptr, this->mIsCausalMask); } void PerformMergedAttention() { using tensorrt_llm::runtime::bufferCast; auto* h_q_ptr = bufferCast(*(this->m_h_q_tensor)); auto* h_kv_ptr = bufferCast(*(this->m_h_kv_full_tensor)); auto* h_chunked_kv_ptr = bufferCast(*(this->m_h_chunked_kv_tensor)); auto* h_output_ptr = bufferCast(*(this->m_h_output_tensor)); auto* h_output_accum_ptr = bufferCast(*(this->m_h_output_tensor_accum)); auto* h_softmax_sum_ptr = bufferCast(*(this->m_h_softmax_sum_tensor)); auto* h_softmax_sum_accum_ptr = bufferCast(*(this->m_h_softmax_sum_accum_tensor)); auto* h_cu_q_seq_lens_ptr = bufferCast(*(this->h_cu_q_seq_lens)); auto* h_cu_kv_seq_lens_ptr = bufferCast(*(this->h_cu_kv_seq_lens)); auto* h_cu_chunk_lens_ptr = bufferCast(*(this->h_cu_chunk_lens)); auto* h_merge_op = bufferCast(*(this->m_h_merge_op)); auto* d_kv_ptr = bufferCast(*(this->m_d_kv_full_tensor)); auto* d_chunked_kv_ptr = bufferCast(*(this->m_d_chunked_kv_tensor)); auto* d_softmax_sum_ptr = bufferCast(*(this->m_d_softmax_sum_tensor)); auto* d_softmax_sum_accum_ptr = bufferCast(*(this->m_d_softmax_sum_accum_tensor)); auto* d_output_ptr = bufferCast(*(this->m_d_output_tensor)); auto* d_output_accum_ptr = bufferCast(*(this->m_d_output_tensor_accum)); auto* d_merge_op = bufferCast(*(this->m_d_merge_op)); auto* d_cu_q_seq_lens_ptr = bufferCast(*(this->d_cu_q_seq_lens)); int const loop_count = (this->mMaxSeqLen + this->mChunkSize - 1) / this->mChunkSize; // do not apply mask for (int _ = 0; _ < loop_count - 1; _++) { // get chunked len for each request this->PrepareChunkedLen(_); cudaMemcpy(d_merge_op, h_merge_op, this->m_h_merge_op->getSizeInBytes(), cudaMemcpyHostToDevice); // copy related kv chunk data copyRelatedChunkedKV(h_chunked_kv_ptr, h_kv_ptr, _, this->mChunkSize, this->mBatchSize, this->mNumHeads, h_cu_kv_seq_lens_ptr, h_cu_chunk_lens_ptr, this->mNopeSize); // attention selfAttentionRef(h_output_ptr, h_q_ptr, h_chunked_kv_ptr, this->mBatchSize, this->mNumHeads, h_cu_q_seq_lens_ptr, 
h_cu_chunk_lens_ptr, this->mNopeSize, true, h_softmax_sum_ptr, false); // merge attention // copy curr_attn and softmax_sum to device cudaMemcpy(d_softmax_sum_ptr, h_softmax_sum_ptr, this->m_h_softmax_sum_tensor->getSizeInBytes(), cudaMemcpyHostToDevice); cudaMemcpy(d_output_ptr, h_output_ptr, this->m_h_output_tensor->getSizeInBytes(), cudaMemcpyHostToDevice); // merge softmax tensorrt_llm::kernels::invokeMergeAttnWithSoftmax(d_output_accum_ptr, d_softmax_sum_accum_ptr, d_output_accum_ptr, d_softmax_sum_accum_ptr, d_output_ptr, d_softmax_sum_ptr, this->mBatchSize, d_cu_q_seq_lens_ptr, this->mMaxQSeqLen, d_merge_op, this->mNumHeads, this->mNopeSize, mStream->get()); cudaStreamSynchronize(mStream->get()); // copy merged softmax sum back to host cudaMemcpy(h_softmax_sum_accum_ptr, d_softmax_sum_accum_ptr, this->m_h_softmax_sum_tensor->getSizeInBytes(), cudaMemcpyDeviceToHost); cudaMemcpy(h_output_accum_ptr, d_output_accum_ptr, this->m_h_output_tensor->getSizeInBytes(), cudaMemcpyDeviceToHost); } // final round, apply causal mask. // copy the last chunked kv data copyFinalChunkedKV(h_chunked_kv_ptr, h_kv_ptr, this->mChunkSize, this->mBatchSize, this->mNumHeads, h_cu_kv_seq_lens_ptr, h_cu_chunk_lens_ptr, this->mNopeSize, h_merge_op); cudaMemcpy(d_merge_op, h_merge_op, this->m_h_merge_op->getSizeInBytes(), cudaMemcpyHostToDevice); #ifdef TRTLLM_MLA_CHUNKED_PREFILL_TEST_DBG std::cout << "merge op: "; this->showHostTensor(this->m_h_merge_op); std::cout << "cu chunk lens: "; this->showHostTensor(this->h_cu_chunk_lens); #endif // attention selfAttentionRef(h_output_ptr, h_q_ptr, h_chunked_kv_ptr, this->mBatchSize, this->mNumHeads, h_cu_q_seq_lens_ptr, h_cu_chunk_lens_ptr, this->mNopeSize, true, h_softmax_sum_ptr, this->mIsCausalMask); // merge attention // copy curr_attn and softmax_sum to device cudaMemcpy(d_softmax_sum_ptr, h_softmax_sum_ptr, this->m_h_softmax_sum_tensor->getSizeInBytes(), cudaMemcpyHostToDevice); cudaMemcpy(d_output_ptr, h_output_ptr, this->m_h_output_tensor->getSizeInBytes(), cudaMemcpyHostToDevice); tensorrt_llm::kernels::invokeMergeAttnWithSoftmax(d_output_accum_ptr, d_softmax_sum_accum_ptr, d_output_accum_ptr, d_softmax_sum_accum_ptr, d_output_ptr, d_softmax_sum_ptr, this->mBatchSize, d_cu_q_seq_lens_ptr, this->mMaxQSeqLen, d_merge_op, this->mNumHeads, this->mNopeSize, mStream->get()); cudaStreamSynchronize(mStream->get()); // copy merged softmax sum back to host cudaMemcpy(h_softmax_sum_accum_ptr, d_softmax_sum_accum_ptr, this->m_h_softmax_sum_tensor->getSizeInBytes(), cudaMemcpyDeviceToHost); cudaMemcpy( h_output_accum_ptr, d_output_accum_ptr, this->m_h_output_tensor->getSizeInBytes(), cudaMemcpyDeviceToHost); sync_check_cuda_error(mStream->get()); } void PrepareChunkedLen(int chunk_idx) { using tensorrt_llm::runtime::bufferCast; auto* h_merge_op = bufferCast(*(this->m_h_merge_op)); auto* h_cu_q_seq_lens_ptr = bufferCast(*(this->h_cu_q_seq_lens)); auto* h_cu_kv_seq_lens_ptr = bufferCast(*(this->h_cu_kv_seq_lens)); auto* h_cu_chunk_lens_ptr = bufferCast(*(this->h_cu_chunk_lens)); h_cu_chunk_lens_ptr[0] = 0; for (int b = 0; b < this->mBatchSize; b++) { int curr_kv_len = h_cu_kv_seq_lens_ptr[b + 1] - h_cu_kv_seq_lens_ptr[b]; int used_kv_len = chunk_idx * this->mChunkSize; int curr_chunk_len = std::min(this->mChunkSize, curr_kv_len - used_kv_len); if (curr_chunk_len != this->mChunkSize) { // last chunk, we should skip it. curr_chunk_len = 0; } else { if (used_kv_len + curr_chunk_len == curr_kv_len) { // last chunk, we should skip it. 
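// (the last, possibly partial, chunk is handled by the final round, where copyFinalChunkedKV sets merge_op and the mask is applied if enabled)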
curr_chunk_len = 0; } } h_cu_chunk_lens_ptr[b + 1] = h_cu_chunk_lens_ptr[b] + curr_chunk_len; if (chunk_idx == 0 && curr_chunk_len > 0) { h_merge_op[b] = 2; // only copy result } else if (curr_chunk_len > 0) { h_merge_op[b] = 1; // merge result } else { h_merge_op[b] = 0; // skip } } #ifdef TRTLLM_MLA_CHUNKED_PREFILL_TEST_DBG std::cout << "merge op: "; this->showHostTensor(this->m_h_merge_op); std::cout << "cu chunk lens: "; this->showHostTensor(this->h_cu_chunk_lens); #endif } void PerformLoadChunkedKVRef(int chunk_idx) { using tensorrt_llm::runtime::bufferCast; auto* compressed_kv_output_ptr = bufferCast(*(this->h_compressed_kv_output_ref)); auto* k_pe_output_ptr = bufferCast(*(this->h_k_pe_output_ref)); auto* compressed_kv_cache_ptr = bufferCast(*(this->h_compressed_kv_cache_tensor)); auto* offset_ptr = bufferCast(*(this->h_compressed_offset_tensor)); auto* h_cu_chunk_lens_ptr = bufferCast(*(this->h_cu_chunk_lens)); float* kv_scale_quant_orig_ptr = nullptr; if constexpr (std::is_same_v) { kv_scale_quant_orig_ptr = bufferCast(*(this->h_kv_scale_quant_orig)); } tensorrt_llm::kernels::KVBlockArray kv_cache(this->mBatchSize, this->mMaxBlockPerSeq, this->mTokensPerBlock, sizeof(TCache) * 1 * (this->mLoraSize + this->mRopeSize), 0, 0, 0, 0, compressed_kv_cache_ptr, nullptr, reinterpret_cast(offset_ptr)); this->PrepareChunkedLen(chunk_idx); loadChunkedKVKernelRef(compressed_kv_output_ptr, k_pe_output_ptr, kv_cache, this->mBatchSize, h_cu_chunk_lens_ptr, this->mLoraSize, this->mRopeSize, this->mChunkSize, chunk_idx, kv_scale_quant_orig_ptr); } void PreformLoadChunkedKV(int chunk_idx) { using tensorrt_llm::runtime::bufferCast; auto* compressed_kv_output_ptr = bufferCast(*(this->d_compressed_kv_output)); auto* k_pe_output_ptr = bufferCast(*(this->d_k_pe_output)); auto* compressed_kv_cache_ptr = bufferCast(*(this->d_compressed_kv_cache_tensor)); auto* offset_ptr = bufferCast(*(this->d_compressed_offset_tensor)); auto* d_cu_chunk_lens_ptr = bufferCast(*(this->d_cu_chunk_lens)); float* kv_scale_quant_orig_ptr = nullptr; if constexpr (std::is_same_v) { kv_scale_quant_orig_ptr = bufferCast(*(this->d_kv_scale_quant_orig)); } tensorrt_llm::kernels::KVBlockArray kv_cache(this->mBatchSize, this->mMaxBlockPerSeq, this->mTokensPerBlock, sizeof(TCache) * 1 * (this->mLoraSize + this->mRopeSize), 0, 0, 0, 0, compressed_kv_cache_ptr, nullptr, reinterpret_cast(offset_ptr)); this->PrepareChunkedLen(chunk_idx); // copy cu chunk lens to device cudaMemcpy(this->d_cu_chunk_lens->data(), this->h_cu_chunk_lens->data(), this->h_cu_chunk_lens->getSizeInBytes(), cudaMemcpyHostToDevice); tensorrt_llm::kernels::invokeMLALoadChunkedKV(compressed_kv_output_ptr, k_pe_output_ptr, kv_cache, this->mBatchSize, d_cu_chunk_lens_ptr, this->mLoraSize, this->mRopeSize, this->mChunkSize, chunk_idx, kv_scale_quant_orig_ptr, mStream->get()); cudaStreamSynchronize(this->mStream->get()); // copy result back to host cudaMemcpy(this->h_compressed_kv_output->data(), compressed_kv_output_ptr, this->h_compressed_kv_output->getSizeInBytes(), cudaMemcpyDeviceToHost); cudaMemcpy(this->h_k_pe_output->data(), k_pe_output_ptr, this->h_k_pe_output->getSizeInBytes(), cudaMemcpyDeviceToHost); sync_check_cuda_error(this->mStream->get()); } void PerformSetChunkedKVRef() { using tensorrt_llm::runtime::bufferCast; auto* kv_ptr = bufferCast(*(this->h_kv_tensor)); auto* k_pe_ptr = bufferCast(*(this->h_k_pe_tensor)); auto* kv_cache_ptr = bufferCast(*(this->h_kv_cache_tensor_ref)); auto* cu_chunked_seq_lens_ptr = bufferCast(*(this->h_cu_chunk_lens)); 
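// chunk 0 cumulative lengths double as the per-request token counts consumed by the set-KV reference below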
this->PrepareChunkedLen(0); setChunkedKVCacheForMLAKernelRef(kv_cache_ptr, kv_ptr, k_pe_ptr, this->mBatchSize, cu_chunked_seq_lens_ptr, this->mChunkSize, this->mNumHeads, this->mNopeSize, this->mRopeSize, this->mTokensPerBlock); } void PerformSetChunkedKV() { using tensorrt_llm::runtime::bufferCast; auto* kv_ptr = bufferCast(*(this->d_kv_tensor)); auto* k_pe_ptr = bufferCast(*(this->d_k_pe_tensor)); auto* kv_cache_ptr = bufferCast(*(this->d_kv_cache_tensor)); auto* cu_chunked_seq_lens_ptr = bufferCast(*(this->d_cu_chunk_lens)); this->PrepareChunkedLen(0); // copy cu chunk lens to device cudaMemcpy(this->d_cu_chunk_lens->data(), this->h_cu_chunk_lens->data(), this->h_cu_chunk_lens->getSizeInBytes(), cudaMemcpyHostToDevice); tensorrt_llm::kernels::invokeMLASetChunkedKV(kv_cache_ptr, kv_ptr, k_pe_ptr, this->mBatchSize, this->mChunkSize, this->mNumHeads, this->mNopeSize, this->mRopeSize, cu_chunked_seq_lens_ptr, this->mTokensPerBlock, mStream->get()); cudaStreamSynchronize(this->mStream->get()); // copy result back to host cudaMemcpy(this->h_kv_cache_tensor->data(), kv_cache_ptr, this->h_kv_cache_tensor->getSizeInBytes(), cudaMemcpyDeviceToHost); sync_check_cuda_error(this->mStream->get()); } }; using MLATypes = ::testing::Types, std::pair<__nv_bfloat16, __nv_bfloat16>, std::pair, std::pair, std::pair<__nv_bfloat16, __nv_fp8_e4m3>, std::pair>; TYPED_TEST_SUITE(MlaChunkedPrefillTest, MLATypes); TYPED_TEST(MlaChunkedPrefillTest, MlaChunkedPrefillDefault) { using tensorrt_llm::runtime::bufferCast; using DataType = typename TestFixture::DataType; using TCache = typename TestFixture::TCache; if constexpr (std::is_same_v) { this->setDefaultParams(); this->allocateBuffers(); sync_check_cuda_error(this->mStream->get()); bool allEqual{true}; this->PerformNormalAttention(); sync_check_cuda_error(this->mStream->get()); this->PerformMergedAttention(); sync_check_cuda_error(this->mStream->get()); // check result auto* output_ptr = bufferCast(*(this->m_h_output_tensor_accum)); auto* output_ref_ptr = bufferCast(*(this->m_h_output_tensor_ref)); for (int i = 0; i < this->m_h_output_tensor->getSize(); i++) { if (std::abs(static_cast(output_ptr[i]) - static_cast(output_ref_ptr[i])) > getTolerance(output_ptr[i])) { std::cout << "Output mismatch at index " << i << ": " << "expected " << static_cast(output_ref_ptr[i]) << ", got " << static_cast(output_ptr[i]) << std::endl; allEqual = false; break; } } ASSERT_TRUE(allEqual); } } TYPED_TEST(MlaChunkedPrefillTest, MlaChunkedPrefillCausalMask) { using tensorrt_llm::runtime::bufferCast; using DataType = typename TestFixture::DataType; using TCache = typename TestFixture::TCache; if constexpr (std::is_same_v) { this->setDefaultParams(); this->mIsCausalMask = true; this->allocateBuffers(); sync_check_cuda_error(this->mStream->get()); bool allEqual{true}; this->PerformNormalAttention(); sync_check_cuda_error(this->mStream->get()); this->PerformMergedAttention(); sync_check_cuda_error(this->mStream->get()); // check result auto* output_ptr = bufferCast(*(this->m_h_output_tensor_accum)); auto* output_ref_ptr = bufferCast(*(this->m_h_output_tensor_ref)); for (int i = 0; i < this->m_h_output_tensor->getSize(); i++) { if (std::abs(static_cast(output_ptr[i]) - static_cast(output_ref_ptr[i])) > getTolerance(output_ptr[i])) { std::cout << "Output mismatch at index " << i << ": " << "expected " << static_cast(output_ref_ptr[i]) << ", got " << static_cast(output_ptr[i]) << std::endl; allEqual = false; break; } } ASSERT_TRUE(allEqual); } } TYPED_TEST(MlaChunkedPrefillTest, 
MlaChunkedLoad) { using tensorrt_llm::runtime::bufferCast; using DataType = typename TestFixture::DataType; this->setDefaultParams(); this->allocateBuffers(); sync_check_cuda_error(this->mStream->get()); bool allEqual{true}; int const loop_count = (this->mMaxSeqLen + this->mChunkSize - 1) / this->mChunkSize; for (int _ = 0; _ < loop_count - 1; _++) { this->PerformLoadChunkedKVRef(_); sync_check_cuda_error(this->mStream->get()); this->PreformLoadChunkedKV(_); sync_check_cuda_error(this->mStream->get()); // check result auto* compressed_kv_output_ptr = bufferCast(*(this->h_compressed_kv_output_ref)); auto* compressed_kv_output_ref_ptr = bufferCast(*(this->h_compressed_kv_output)); auto* k_pe_output_ptr = bufferCast(*(this->h_k_pe_output)); auto* k_pe_output_ref_ptr = bufferCast(*(this->h_k_pe_output_ref)); // check kv for (int i = 0; i < this->h_compressed_kv_output->getSize(); i++) { if (std::abs(static_cast(compressed_kv_output_ptr[i]) - static_cast(compressed_kv_output_ref_ptr[i])) > getTolerance(compressed_kv_output_ptr[i])) { std::cout << "Compressed KV output mismatch at loop: " << _ << " index " << i << ": " << "expected " << static_cast(compressed_kv_output_ref_ptr[i]) << ", got " << static_cast(compressed_kv_output_ptr[i]) << std::endl; allEqual = false; break; } } // check k_pe for (int i = 0; i < this->h_k_pe_output->getSize(); i++) { if (std::abs(static_cast(k_pe_output_ptr[i]) - static_cast(k_pe_output_ref_ptr[i])) > getTolerance(k_pe_output_ptr[i])) { std::cout << "kpe mismatch at loop: " << _ << " index " << i << ": " << "expected " << static_cast(k_pe_output_ref_ptr[i]) << ", got " << static_cast(k_pe_output_ptr[i]) << std::endl; allEqual = false; break; } } } ASSERT_TRUE(allEqual); } TYPED_TEST(MlaChunkedPrefillTest, MlaChunkedSet) { using tensorrt_llm::runtime::bufferCast; using DataType = typename TestFixture::DataType; using TCache = typename TestFixture::TCache; if constexpr (std::is_same_v) { this->setDefaultParams(); this->allocateBuffers(); sync_check_cuda_error(this->mStream->get()); bool allEqual{true}; this->PerformSetChunkedKVRef(); sync_check_cuda_error(this->mStream->get()); this->PerformSetChunkedKV(); sync_check_cuda_error(this->mStream->get()); // check result auto* kv_cache_ptr = bufferCast(*(this->h_kv_cache_tensor)); auto* kv_cache_ptr_ref = bufferCast(*(this->h_kv_cache_tensor_ref)); for (int i = 0; i < this->h_kv_cache_tensor->getSize(); i++) { if (std::abs(static_cast(kv_cache_ptr[i]) - static_cast(kv_cache_ptr_ref[i])) > getTolerance(kv_cache_ptr[i])) { std::cout << "KV cache mismatch at index " << i << ": " << "expected " << static_cast(kv_cache_ptr_ref[i]) << ", got " << static_cast(kv_cache_ptr[i]) << std::endl; allEqual = false; break; } } ASSERT_TRUE(allEqual); } }
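
// ---------------------------------------------------------------------------
// Illustrative sketch (not exercised by the tests above): how two partial
// attention results for the same query row can be merged from their softmax
// statistics, in the spirit of the merge_op == 1 path of
// invokeMergeAttnWithSoftmax. Here m_i is the maximum logit of chunk i (in the
// same units that were exponentiated), s_i its softmax denominator, and o_i
// the already-normalized partial output. The name mergeSoftmaxRowSketch and
// the row-wise signature are assumptions made for this sketch, not the
// kernel's actual interface.
//   m = max(m1, m2), w_i = s_i * exp(m_i - m),
//   o = (w1 * o1 + w2 * o2) / (w1 + w2), s = w1 + w2.
template <typename T>
void mergeSoftmaxRowSketch(T* merged, float& merged_max, float& merged_sum, T const* o1, float m1, float s1,
    T const* o2, float m2, float s2, int head_size)
{
    float const m = std::max(m1, m2);
    float const w1 = s1 * std::exp(m1 - m);
    float const w2 = s2 * std::exp(m2 - m);
    for (int d = 0; d < head_size; d++)
    {
        merged[d] = static_cast<T>((w1 * static_cast<float>(o1[d]) + w2 * static_cast<float>(o2[d])) / (w1 + w2));
    }
    merged_max = m;
    merged_sum = w1 + w2;
}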