diff --git a/cpp/tensorrt_llm/common/attentionOp.cpp b/cpp/tensorrt_llm/common/attentionOp.cpp index 51bccb7e6e..8a7f3dc660 100644 --- a/cpp/tensorrt_llm/common/attentionOp.cpp +++ b/cpp/tensorrt_llm/common/attentionOp.cpp @@ -24,6 +24,7 @@ #include "tensorrt_llm/kernels/gptKernels.h" #include "tensorrt_llm/kernels/kvCacheUtils.h" #include "tensorrt_llm/kernels/multiHeadAttentionCommon.h" +#include "tensorrt_llm/kernels/sparseAttentionKernels.h" #include "tensorrt_llm/kernels/unfusedAttentionKernels.h" #include "tensorrt_llm/runtime/iBuffer.h" #include "tensorrt_llm/runtime/utils/debugUtils.h" @@ -287,6 +288,9 @@ bool AttentionOp::convertMMHAParamsToXQAParams(tensorrt_llm::kernels::XQAParams& xqaParams.output_sf = generationsParams.context_buf_sf; xqaParams.fp4_out_sf_scale = generationsParams.attention_output_sf_scale; xqaParams.start_token_idx_sf = generationsParams.start_token_idx_sf; + // Parameters for sparse attention + xqaParams.sparse_params = mRuntimeSparseAttentionParams; + xqaParams.use_sparse_attention = useTllmGenSparseAttention(); // Cross attention parameters. xqaParams.encoder_input_lengths = generationsParams.encoder_input_lengths; @@ -813,7 +817,7 @@ size_t AttentionOp::getWorkspaceSizeForContext(nvinfer1::DataType type, int32_t } size_t AttentionOp::getWorkspaceSizeForGeneration(nvinfer1::DataType type, int32_t max_num_seq, - int32_t max_attention_window_size, int32_t max_num_tokens) const noexcept + int32_t max_attention_window_size, int32_t max_num_tokens, int32_t max_blocks_per_sequence) const noexcept { if (max_num_tokens == 0) { @@ -909,11 +913,15 @@ size_t AttentionOp::getWorkspaceSizeForGeneration(nvinfer1::DataType type, int32 size_t xqa_workspace_size = 0; if (mEnableXQA) { - int const XQA_NUM_BUFFERS = 7; + int const XQA_NUM_BUFFERS = 8; size_t xqa_workspaces[XQA_NUM_BUFFERS]; size_t const cu_seqlens_size = sizeof(int) * (batch_beam + 1); size_t const cu_kv_seqlens_size = sizeof(int) * (batch_beam + 1); size_t const rotary_inv_freq_size = sizeof(float) * batch_beam * mRotaryEmbeddingDim / 2; + // Two workspaces for sparse attention. One for the sequence lengths, and one for kv block offsets. + size_t const sparse_attn_cache_size = useTllmGenSparseAttention() + ? sizeof(int) * (batch_beam + batch_beam * 2 * max_blocks_per_sequence) * mNumKVHeads + : 0; xqa_workspaces[0] = cu_seqlens_size; xqa_workspaces[1] = cu_kv_seqlens_size; xqa_workspaces[2] = rotary_inv_freq_size; @@ -922,7 +930,8 @@ size_t AttentionOp::getWorkspaceSizeForGeneration(nvinfer1::DataType type, int32 // Scales used for trtllm-gen kernels. 
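// bmm1 scale buffer holds two floats (scale and softmax-log2 scale); bmm2 scale buffer holds one.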
xqa_workspaces[4] = sizeof(float) * 2; xqa_workspaces[5] = sizeof(float); - xqa_workspaces[6] = mXqaDispatcher->getWorkspaceSize( + xqa_workspaces[6] = sparse_attn_cache_size; + xqa_workspaces[7] = mXqaDispatcher->getWorkspaceSize( std::min(mSpecDecodingMaxGenerationLength * max_num_seq, max_num_tokens)); xqa_workspace_size = tc::calculateTotalWorkspaceSize(xqa_workspaces, XQA_NUM_BUFFERS, mXqaDispatcher->getWorkspaceAlignment()); @@ -1647,6 +1656,10 @@ int AttentionOp::enqueueContext(EnqueueContextParams const& params, cudaStrea preprocessingParams.spec_decoding_position_offsets = nullptr; preprocessingParams.logn_scaling = params.logn_scaling_ptr; + // Sparse KV write + preprocessingParams.sparse_kv_indices = mRuntimeSparseAttentionParams.sparse_kv_indices; + preprocessingParams.sparse_kv_offsets = mRuntimeSparseAttentionParams.sparse_kv_offsets; + // Scalars preprocessingParams.batch_size = params.batch_size; preprocessingParams.max_input_seq_len = params.input_seq_length; @@ -1676,6 +1689,8 @@ int AttentionOp::enqueueContext(EnqueueContextParams const& params, cudaStrea preprocessingParams.rotary_vision_start = mVisionStart; preprocessingParams.rotary_vision_length = mVisionLength; + preprocessingParams.is_last_chunk + = !mAttentionChunkSize.has_value() || (params.input_seq_length == params.max_past_kv_length); { std::string const beforeRopeStr = "ctx attention before RoPE at layer " + std::to_string(mLayerIdx); @@ -1841,6 +1856,12 @@ int AttentionOp::enqueueContext(EnqueueContextParams const& params, cudaStrea gatherInBuffer, params, cu_q_seqlens, cu_cp_partial_seqlens, stream); sync_check_cuda_error(stream); } + + if (!mIsMLAEnabled) // Only for non-MLA attention + { + invokeKvCachePostprocessing(preprocessingParams, stream); + sync_check_cuda_error(stream); + } } else { diff --git a/cpp/tensorrt_llm/common/attentionOp.h b/cpp/tensorrt_llm/common/attentionOp.h index f33194c02f..b8aef87e85 100644 --- a/cpp/tensorrt_llm/common/attentionOp.h +++ b/cpp/tensorrt_llm/common/attentionOp.h @@ -26,6 +26,7 @@ #include "tensorrt_llm/kernels/gptKernels.h" #include "tensorrt_llm/kernels/kvCacheUtils.h" #include "tensorrt_llm/kernels/mlaKernels.h" +#include "tensorrt_llm/kernels/sparseAttentionKernels.h" #include "tensorrt_llm/kernels/xqaDispatcher.h" #include #include @@ -55,7 +56,7 @@ public: int32_t cross_kv_length = 0, int32_t max_num_tokens = 0) const noexcept; // total_num_seq is the sum of beam_width for multiple requests [[nodiscard]] size_t getWorkspaceSizeForGeneration(nvinfer1::DataType type, int32_t total_num_seq, - int32_t max_attention_window_size, int32_t max_num_tokens) const noexcept; + int32_t max_attention_window_size, int32_t max_num_tokens, int32_t max_blocks_per_sequence) const noexcept; template class EnqueueParams @@ -156,14 +157,20 @@ public: ss << "max_cyclic_attention_window_size: " << this->max_cyclic_attention_window_size << std::endl; ss << "can_use_one_more_block: " << (this->can_use_one_more_block ? 
"true" : "false") << std::endl; ss << "sink_token_length: " << this->sink_token_length << std::endl; - ss << "context_lengths: " - << *(runtime::ITensor::wrap((void*) this->context_lengths, nvinfer1::DataType::kINT32, - runtime::ITensor::makeShape({batch_size}))) - << std::endl; - ss << "sequence_lengths: " - << *(runtime::ITensor::wrap((void*) this->sequence_lengths, nvinfer1::DataType::kINT32, - runtime::ITensor::makeShape({batch_size}))) - << std::endl; + if (this->context_lengths && batch_size > 0) + { + ss << "context_lengths: " + << *(runtime::ITensor::wrap((void*) this->context_lengths, nvinfer1::DataType::kINT32, + runtime::ITensor::makeShape({batch_size}))) + << std::endl; + } + if (this->sequence_lengths && batch_size > 0) + { + ss << "sequence_lengths: " + << *(runtime::ITensor::wrap((void*) this->sequence_lengths, nvinfer1::DataType::kINT32, + runtime::ITensor::makeShape({batch_size}))) + << std::endl; + } ss << "kv_scale_orig_quant: " << this->kv_scale_orig_quant << std::endl; ss << "kv_scale_quant_orig: " << this->kv_scale_quant_orig << std::endl; ss << "attention_output_orig_quant: " << this->attention_output_orig_quant << std::endl; @@ -348,6 +355,16 @@ public: return mIsMLAEnabled; } + [[nodiscard]] bool useSparseAttention() const + { + return mUseSparseAttention && mPagedKVCache && mEnableXQA; + } + + [[nodiscard]] bool useTllmGenSparseAttention() const + { + return mUseTllmGenSparseAttention && useSparseAttention(); + } + [[nodiscard]] int smVersion() const { return mSM; @@ -427,6 +444,8 @@ public: bool mIsMLAEnabled = false; bool mIsGenerationMLA = false; bool mUseGenFlashMLA = false; + bool mUseSparseAttention = false; + bool mUseTllmGenSparseAttention = false; tensorrt_llm::kernels::MlaMetaParams mMLAParams; int mCpSize = 1; int mCpRank = 0; @@ -454,6 +473,8 @@ public: // Whether to fuse FP4 quant into attention kernel. 
bool mFuseFp4Quant = false; + kernels::SparseAttentionParams mRuntimeSparseAttentionParams; + // This is implementation details which we want to save when serializing, but not expose as // a plugin field or a constructor parameter int32_t mNbMultiBlockSemaphores = 0; @@ -473,10 +494,11 @@ public: mPosShiftEnabled, mPagedContextFMHA, mFP8ContextFMHA, mFP8AttenOutput, mFP8ContextMLA, mFP8GenerationMLA, mChunkPrefillBufferBatchSize, mDenseContextFMHA, mHasFullAttentionMask, mIsSpecDecodingEnabled, mUseSpecDecoding, mIsSpecDecTree, mSpecDecodingIsGenerationLengthVariable, mSpecDecodingMaxGenerationLength, - mIsMLAEnabled, mIsGenerationMLA, mUseGenFlashMLA, mMLAParams.data(), mCpSize, mCpRank, mCpGroup, - mNumAttnHeads, mNumAttnKVHeads, mNumKVHeadsOrigin, mAttnTpSize, mAttnTpRank, mAttnCpSize, mAttnCpRank, - mUlyssesMQABroadcast, mEnableContextFMHA, mFMHAForceFP32Acc, mMultiBlockMode, mEnableXQA, mUseKVCache, - mSkipAttn, mFuseFp4Quant, mNbMultiBlockSemaphores, mAttentionChunkSize.value_or(-1)); + mIsMLAEnabled, mIsGenerationMLA, mUseGenFlashMLA, mUseSparseAttention, mUseTllmGenSparseAttention, + mMLAParams.data(), mCpSize, mCpRank, mCpGroup, mNumAttnHeads, mNumAttnKVHeads, mNumKVHeadsOrigin, + mAttnTpSize, mAttnTpRank, mAttnCpSize, mAttnCpRank, mUlyssesMQABroadcast, mEnableContextFMHA, + mFMHAForceFP32Acc, mMultiBlockMode, mEnableXQA, mUseKVCache, mSkipAttn, mFuseFp4Quant, + mRuntimeSparseAttentionParams.data(), mNbMultiBlockSemaphores, mAttentionChunkSize.value_or(-1)); }; private: diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplCommon.h b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplCommon.h index a1a3f049d3..f2dcb7a858 100644 --- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplCommon.h +++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplCommon.h @@ -233,6 +233,8 @@ struct XQALaunchParam float* bmm2_scale_ptr = nullptr; int32_t* semaphores = nullptr; void* scratch = nullptr; + void* sparse_kv_block_offsets = nullptr; + int32_t* sparse_seq_lengths = nullptr; }; // Setup launch params and ioScratch. ioScratch is for RoPE and output type conversion. 
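// When sparse attention is enabled, the workspace additionally carves out the gathered per-head KV page offsets and per-head sequence lengths.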
@@ -266,6 +268,9 @@ void buildXQALaunchParams(XQALaunchParam& launchParams, void*& in const size_t cu_kv_seqlens_size = sizeof(int) * (batch_beam_size + 1); const size_t rotary_inv_freq_size = sizeof(float) * batch_beam_size * params.rotary_embedding_dim / 2; const size_t tokens_info_size = sizeof(int2) * params.total_num_input_tokens; + const size_t kv_block_offsets_size + = sizeof(int) * batch_beam_size * 2 * params.max_blocks_per_sequence * params.num_kv_heads; + const size_t seq_lengths_size = sizeof(int) * batch_beam_size * params.num_kv_heads; launchParams.cu_seq_lens = reinterpret_cast(workspace); workspace = tensorrt_llm::common::nextWorkspacePtrWithAlignment(workspace, cu_seqlens_size); launchParams.cu_kv_seq_lens = reinterpret_cast(workspace); @@ -281,6 +286,14 @@ void buildXQALaunchParams(XQALaunchParam& launchParams, void*& in workspace = tensorrt_llm::common::nextWorkspacePtrWithAlignment(workspace, bmm1_scale_size); launchParams.bmm2_scale_ptr = reinterpret_cast(workspace); workspace = tensorrt_llm::common::nextWorkspacePtrWithAlignment(workspace, bmm2_scale_size); + // Used for block sparse attention + if (params.use_sparse_attention) + { + launchParams.sparse_kv_block_offsets = reinterpret_cast(workspace); + workspace = tensorrt_llm::common::nextWorkspacePtrWithAlignment(workspace, kv_block_offsets_size); + launchParams.sparse_seq_lengths = reinterpret_cast(workspace); + workspace = tensorrt_llm::common::nextWorkspacePtrWithAlignment(workspace, seq_lengths_size); + } inputScratch = workspace; if (hasOutputScratch) { diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/xqaParams.h b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/xqaParams.h index fcf8ab3851..f5f712a6d8 100644 --- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/xqaParams.h +++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/xqaParams.h @@ -17,6 +17,7 @@ #include "tensorrt_llm/common/quantization.h" #include "tensorrt_llm/kernels/gptKernels.h" #include "tensorrt_llm/kernels/multiHeadAttentionCommon.h" +#include "tensorrt_llm/kernels/sparseAttentionKernels.h" namespace tensorrt_llm { @@ -109,6 +110,10 @@ struct XQAParams // for cross attention int32_t const* encoder_input_lengths = nullptr; + // sparse attention parameters + SparseAttentionParams sparse_params; + bool use_sparse_attention = false; + cudaStream_t stream = 0; std::string toString() const @@ -179,6 +184,8 @@ struct XQAParams << "is_fp8_output :" << (is_fp8_output ? "true" : "false") << std ::endl << "fp8_out_scale :" << fp8_out_scale << std ::endl << "encoder_input_lengths: " << encoder_input_lengths << std::endl + << "sparse_params: " << sparse_params.toString() << std::endl + << "use_sparse_attention :" << (use_sparse_attention ? "true" : "false") << std ::endl << "stream :" << stream; return ss.str(); diff --git a/cpp/tensorrt_llm/kernels/sparseAttentionKernels.cu b/cpp/tensorrt_llm/kernels/sparseAttentionKernels.cu new file mode 100644 index 0000000000..1873d66124 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/sparseAttentionKernels.cu @@ -0,0 +1,132 @@ +/* + * Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tensorrt_llm/kernels/sparseAttentionKernels.h" +#include + +namespace tensorrt_llm +{ +namespace kernels +{ +template +__global__ void gatherKvPageOffsetsKernel( + int32_t* output_kv_page_offsets, // [num_head_kv, batch_size, 2, max_num_pages_per_seq] + int32_t* output_seq_lengths, // [num_head_kv, batch_size] + int32_t const* kv_page_offsets, // [batch_size, 2, max_num_pages_per_seq] + int32_t const* seq_lengths, // [batch_size] + SparseAttentionParams const sparse_params, int32_t const batch_size, int32_t const tokens_per_page, + int32_t const max_num_pages_per_seq) +{ + // Each CUDA block processes one sequence from the batch for one head. + int32_t const head_idx = blockIdx.x; + int32_t const batch_idx = blockIdx.y; + if (batch_idx >= batch_size) + { + return; + } + + // Shared memory for reduction. + __shared__ typename cub::BlockReduce::TempStorage temp_storage; + + // Get the range of sparse indices and the sequence length. + int32_t const start_offset = sparse_params.sparse_attn_offsets[batch_idx]; + int32_t const end_offset = sparse_params.sparse_attn_offsets[batch_idx + 1]; + int32_t const total_pages = sparse_params.sparse_attn_offsets[batch_size]; + int32_t const num_sparse_pages = end_offset - start_offset; + int32_t const original_seq_len = seq_lengths[batch_idx]; + + // Get global sparse index. + int32_t const sparse_idx_global = head_idx * total_pages + start_offset; + + // Get the base memory offset. shape: [batch_size, 2, max_num_pages_per_seq] + size_t const src_base_offset = (size_t) batch_idx * 2 * max_num_pages_per_seq; + size_t const dst_base_offset = (size_t) head_idx * batch_size * 2 * max_num_pages_per_seq + src_base_offset; + + // Initialize the local max page index and number of valid pages. + int32_t local_max_page_index = -1; + int32_t local_num_valid_pages = 0; + + // Perform the gather operation. + for (int32_t i = threadIdx.x; i < num_sparse_pages; i += blockDim.x) + { + // Get the source idx and offset. + int32_t const src_idx = sparse_params.sparse_attn_indices[sparse_idx_global + i]; + if (src_idx < 0) + { + continue; + } + + // Update the local max page index. + local_max_page_index = max(local_max_page_index, src_idx); + local_num_valid_pages++; + + // Get the source and destination offsets. + size_t const src_offset_dim0 = src_base_offset + 0 * max_num_pages_per_seq + src_idx; + size_t const src_offset_dim1 = src_base_offset + 1 * max_num_pages_per_seq + src_idx; + size_t const dst_offset_dim0 = dst_base_offset + 0 * max_num_pages_per_seq + i; + size_t const dst_offset_dim1 = dst_base_offset + 1 * max_num_pages_per_seq + i; + + // Perform the gather operation: read from the sparse location and write to the dense location. + output_kv_page_offsets[dst_offset_dim0] = kv_page_offsets[src_offset_dim0]; + output_kv_page_offsets[dst_offset_dim1] = kv_page_offsets[src_offset_dim1]; + } + + // Reduce the local max page indices and number of valid pages. 
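+    // A single block-wide reduction combines both values: max over page indices, sum over valid-page counts.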
+ Pair local_pair = {local_max_page_index, local_num_valid_pages}; + Pair result = cub::BlockReduce(temp_storage).Reduce(local_pair, PairReduceOp()); + + // Update sequence length for this head and batch. + if (threadIdx.x == 0) + { + int32_t const max_page_index = result.max_val; + int32_t const num_valid_pages = result.sum_val; + int32_t const ori_valid_pages = (original_seq_len + tokens_per_page - 1) / tokens_per_page; + size_t const seq_len_offset = (size_t) head_idx * batch_size + batch_idx; + if (num_valid_pages > 0) + { + int32_t seq_len = original_seq_len - (ori_valid_pages - num_valid_pages) * tokens_per_page; + int32_t seq_len_remain = original_seq_len % tokens_per_page; + if (max_page_index != ori_valid_pages - 1 && seq_len_remain != 0) + { + seq_len += tokens_per_page - seq_len_remain; + } + output_seq_lengths[seq_len_offset] = seq_len; + } + else + { + output_seq_lengths[seq_len_offset] = 0; + } + } +} + +// Host-side launcher function +void invokeGatherKvPageOffsets(int32_t* output_kv_page_offsets, int32_t* output_seq_lengths, + int32_t const* kv_page_offsets, int32_t const* seq_lengths, SparseAttentionParams const sparse_params, + int32_t const batch_size, int32_t const num_head_kv, int32_t const tokens_per_page, + int32_t const max_num_pages_per_seq, cudaStream_t stream) +{ + // The grid. + dim3 grid(num_head_kv, batch_size, 1); + // The block. + dim3 block(256, 1, 1); + // Shared memory size. + size_t smem_size = sizeof(Pair) * 256; + + // Launch the kernel. + gatherKvPageOffsetsKernel<256><<>>(output_kv_page_offsets, output_seq_lengths, + kv_page_offsets, seq_lengths, sparse_params, batch_size, tokens_per_page, max_num_pages_per_seq); +} +} // namespace kernels +} // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/sparseAttentionKernels.h b/cpp/tensorrt_llm/kernels/sparseAttentionKernels.h new file mode 100644 index 0000000000..8d5b9c9ec1 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/sparseAttentionKernels.h @@ -0,0 +1,81 @@ +/* + * Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include +#include +#include +#include +#include + +namespace tensorrt_llm +{ +namespace kernels +{ + +struct SparseAttentionParams +{ + int32_t* sparse_kv_indices{nullptr}; // [num_kv_heads, num_sparse_kv_indices] + int32_t* sparse_attn_indices{nullptr}; // [num_kv_heads, num_sparse_attn_indices] + int32_t* sparse_kv_offsets{nullptr}; // [num_contexts + 1] + int32_t* sparse_attn_offsets{nullptr}; // [num_generations + 1] + + std::string toString() const + { + std::stringstream ss; + ss << "sparse_kv_indices: " << this->sparse_kv_indices << std::endl + << "sparse_attn_indices: " << this->sparse_attn_indices << std::endl + << "sparse_kv_offsets: " << this->sparse_kv_offsets << std::endl + << "sparse_attn_offsets: " << this->sparse_attn_offsets << std::endl; + return ss.str(); + } + + auto data() const + { + return std::make_tuple(sparse_kv_indices, sparse_attn_indices, sparse_kv_offsets, sparse_attn_offsets); + } +}; + +struct Pair +{ + int32_t max_val; + int32_t sum_val; +}; + +struct PairReduceOp +{ +#if defined(__CUDACC__) + inline __device__ +#endif + Pair + operator()(Pair const& a, Pair const& b) const + { + Pair result; + result.max_val = a.max_val > b.max_val ? a.max_val : b.max_val; + result.sum_val = a.sum_val + b.sum_val; + return result; + } +}; + +void invokeGatherKvPageOffsets(int32_t* output_kv_page_offsets, // [num_head_kv, batch_size, 2, max_num_pages_per_seq] + int32_t* output_seq_lengths, // [num_head_kv, batch_size] + int32_t const* kv_page_offsets, // [batch_size, 2, max_num_pages_per_seq] + int32_t const* seq_lengths, // [batch_size] + SparseAttentionParams const sparse_params, int32_t const batch_size, int32_t const num_head_kv, + int32_t const tokens_per_page, int32_t const max_num_pages_per_seq, cudaStream_t stream); + +} // namespace kernels +} // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaKernels.h b/cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaKernels.h index 4f4c1e72f0..a297db78f0 100644 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaKernels.h +++ b/cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaKernels.h @@ -207,8 +207,6 @@ public: // Prepare the kernel parameters. auto kernelParams = KernelParams::setKernelParams(params, kernelMeta, maxNumCtasQ, maxNumCtasKv); - // TODO: set the block sparse attention flag. - kernelParams.mUseBlockSparseAttention = false; // Prepare kernel parameters list for cuLaunchKernelEx. void* kernelParamsList[] = {&kernelParams}; diff --git a/cpp/tensorrt_llm/kernels/unfusedAttentionKernels.h b/cpp/tensorrt_llm/kernels/unfusedAttentionKernels.h index 22fa7d464e..57fd40b78c 100644 --- a/cpp/tensorrt_llm/kernels/unfusedAttentionKernels.h +++ b/cpp/tensorrt_llm/kernels/unfusedAttentionKernels.h @@ -148,6 +148,10 @@ struct QKVPreprocessingParams int const* cu_seq_lens{nullptr}; // list of cumulative KV sequence lengths, of shape {batch_size + 1}, used by cross attention only. 
int const* cu_kv_seq_lens{nullptr}; + // list of cumulative length of sparse KV indices, of shape {batch_size + 1} + int const* sparse_kv_offsets{nullptr}; + // list of sparse KV indices for writing to KV cache, of shape {num_kv_heads, num_sparse_kv_indices} + int const* sparse_kv_indices{nullptr}; // inverse frequencies (angle raised at various powers) from the RoPE formula // shape of {batch_size , rotaryEmbeddingDim / 2} float const* rotary_embedding_inv_freq{nullptr}; @@ -167,6 +171,7 @@ struct QKVPreprocessingParams int sink_token_len{0}; int token_num{0}; bool remove_padding{true}; + bool is_last_chunk{true}; bool cross_attention{false}; int head_num{0}; int kv_head_num{0}; @@ -216,24 +221,48 @@ struct QKVPreprocessingParams ss << "kv_cache_block_scales_buffer: " << kv_cache_block_scales_buffer.data << std::endl; ss << "qkv_bias: " << qkv_bias << std::endl; ss << "tokens_info: " << tokens_info << std::endl; - ss << "seq_lens: " - << *(runtime::ITensor::wrap( - (void*) seq_lens, nvinfer1::DataType::kINT32, runtime::ITensor::makeShape({batch_size}))); - ss << "cache_seq_lens: " - << *(runtime::ITensor::wrap( - (void*) cache_seq_lens, nvinfer1::DataType::kINT32, runtime::ITensor::makeShape({batch_size}))); - ss << "encoder_seq_lens: " - << *(runtime::ITensor::wrap( - (void*) encoder_seq_lens, nvinfer1::DataType::kINT32, runtime::ITensor::makeShape({batch_size}))); - ss << "cu_seq_lens: " - << *(runtime::ITensor::wrap( - (void*) cu_seq_lens, nvinfer1::DataType::kINT32, runtime::ITensor::makeShape({batch_size}))); - ss << "cu_kv_seq_lens: " - << *(runtime::ITensor::wrap( - (void*) cu_kv_seq_lens, nvinfer1::DataType::kINT32, runtime::ITensor::makeShape({batch_size}))); - ss << "rotary_embedding_inv_freq: " - << *(runtime::ITensor::wrap((void*) rotary_embedding_inv_freq, nvinfer1::DataType::kFLOAT, - runtime::ITensor::makeShape({batch_size, rotary_embedding_dim / 2}))); + if (seq_lens && batch_size > 0) + { + ss << "seq_lens: " + << *(runtime::ITensor::wrap( + (void*) seq_lens, nvinfer1::DataType::kINT32, runtime::ITensor::makeShape({batch_size}))); + } + if (cache_seq_lens && batch_size > 0) + { + ss << "cache_seq_lens: " + << *(runtime::ITensor::wrap( + (void*) cache_seq_lens, nvinfer1::DataType::kINT32, runtime::ITensor::makeShape({batch_size}))); + } + if (encoder_seq_lens && batch_size > 0) + { + ss << "encoder_seq_lens: " + << *(runtime::ITensor::wrap( + (void*) encoder_seq_lens, nvinfer1::DataType::kINT32, runtime::ITensor::makeShape({batch_size}))); + } + if (cu_seq_lens && batch_size > 0) + { + ss << "cu_seq_lens: " + << *(runtime::ITensor::wrap( + (void*) cu_seq_lens, nvinfer1::DataType::kINT32, runtime::ITensor::makeShape({batch_size}))); + } + if (cu_kv_seq_lens && batch_size > 0) + { + ss << "cu_kv_seq_lens: " + << *(runtime::ITensor::wrap( + (void*) cu_kv_seq_lens, nvinfer1::DataType::kINT32, runtime::ITensor::makeShape({batch_size}))); + } + if (sparse_kv_offsets) + { + ss << "sparse_kv_offsets: " + << *(runtime::ITensor::wrap((void*) sparse_kv_offsets, nvinfer1::DataType::kINT32, + runtime::ITensor::makeShape({batch_size + 1}))); + } + if (rotary_embedding_inv_freq && batch_size > 0 && rotary_embedding_dim > 0) + { + ss << "rotary_embedding_inv_freq: " + << *(runtime::ITensor::wrap((void*) rotary_embedding_inv_freq, nvinfer1::DataType::kFLOAT, + runtime::ITensor::makeShape({batch_size, rotary_embedding_dim / 2}))); + } ss << "rotary_coef_cache_buffer: " << rotary_coef_cache_buffer << std::endl; ss << "qkv_scale_orig_quant: " << qkv_scale_orig_quant << std::endl; ss << 
"spec_decoding_position_offsets: " << spec_decoding_position_offsets << std::endl; @@ -244,6 +273,7 @@ struct QKVPreprocessingParams ss << "sink_token_len: " << sink_token_len << std::endl; ss << "token_num: " << token_num << std::endl; ss << "remove_padding: " << remove_padding << std::endl; + ss << "is_last_chunk: " << is_last_chunk << std::endl; ss << "cross_attention: " << cross_attention << std::endl; ss << "head_num: " << head_num << std::endl; ss << "kv_head_num: " << kv_head_num << std::endl; @@ -362,23 +392,36 @@ void invokeQKVPreprocessing(QKVPreprocessingParams params, cud template void invokeUpdateCyclicKvCacheAfterFmha(QKVPreprocessingParams params, cudaStream_t stream); +template +void invokeUpdateSparseKvCacheAfterFmha(QKVPreprocessingParams params, cudaStream_t stream); + +// Debug function to test basic parameter access +template +void invokeDebugSparseKvCacheParams( + QKVPreprocessingParams params, int* debug_output, cudaStream_t stream); + template void invokeKvCachePostprocessing(QKVPreprocessingParams params, cudaStream_t stream) { params.setCommonParameters(); - if (params.cache_type == KvCacheDataType::INT8) + + // handle sparse KV cache update if needed + if (params.sparse_kv_indices != nullptr && params.sparse_kv_offsets != nullptr && params.is_last_chunk) { - invokeUpdateCyclicKvCacheAfterFmha(params, stream); - } + if (params.cache_type == KvCacheDataType::INT8) + { + invokeUpdateSparseKvCacheAfterFmha(params, stream); + } #ifdef ENABLE_FP8 - else if (params.cache_type == KvCacheDataType::FP8) - { - invokeUpdateCyclicKvCacheAfterFmha(params, stream); - } + else if (params.cache_type == KvCacheDataType::FP8) + { + invokeUpdateSparseKvCacheAfterFmha(params, stream); + } #endif // ENABLE_FP8 - else - { - invokeUpdateCyclicKvCacheAfterFmha(params, stream); + else + { + invokeUpdateSparseKvCacheAfterFmha(params, stream); + } } } diff --git a/cpp/tensorrt_llm/kernels/unfusedAttentionKernels/unfusedAttentionKernels_2_template.h b/cpp/tensorrt_llm/kernels/unfusedAttentionKernels/unfusedAttentionKernels_2_template.h index fb86622031..abe76a5902 100644 --- a/cpp/tensorrt_llm/kernels/unfusedAttentionKernels/unfusedAttentionKernels_2_template.h +++ b/cpp/tensorrt_llm/kernels/unfusedAttentionKernels/unfusedAttentionKernels_2_template.h @@ -1709,6 +1709,130 @@ void invokeUpdateCyclicKvCacheAfterFmha(QKVPreprocessingParams //////////////////////////////////////////////////////////////////////////////////////////////////// +template +__global__ __launch_bounds__(BLOCK_SIZE) void updateSparseKvCacheAfterFmha( + QKVPreprocessingParams params) +{ + // The number of 16B vectors per head size in the kv cache. 
+ constexpr int VECS_PER_HEAD = Dh * sizeof(TCache) / 16; + static_assert(BLOCK_SIZE % VECS_PER_HEAD == 0, "Kernel block should be able to handle entire heads."); + + int const batch_idx = blockIdx.z; + int const kv_head_idx = blockIdx.y; + + int const total_num_sparse_kv_tokens = params.sparse_kv_offsets[params.batch_size]; + + int const sparse_start_idx = params.sparse_kv_offsets[batch_idx]; + int const sparse_end_idx = params.sparse_kv_offsets[batch_idx + 1]; + int const num_sparse_tokens = sparse_end_idx - sparse_start_idx; + + int const tokens_per_block = blockDim.y; + int const vecs_per_block = blockDim.x; + + extern __shared__ uint4 smem[]; + uint4* k_smem = smem; + uint4* v_smem = k_smem + tokens_per_block * VECS_PER_HEAD; + + for (int token_block_offset = 0; token_block_offset < num_sparse_tokens; token_block_offset += tokens_per_block) + { + int const sparse_token_offset = token_block_offset + threadIdx.y; + + if (sparse_token_offset < num_sparse_tokens) + { + int const global_sparse_idx = sparse_start_idx + sparse_token_offset; + int const sparse_idx_offset = kv_head_idx * total_num_sparse_kv_tokens + global_sparse_idx; + + int const src_token_idx = params.sparse_kv_indices[sparse_idx_offset]; + + void* src_k_ptr = params.kv_cache_buffer.getKBlockPtr(batch_idx, src_token_idx); + void* src_v_ptr = params.kv_cache_buffer.getVBlockPtr(batch_idx, src_token_idx); + auto const src_k_block_ptr = reinterpret_cast(src_k_ptr); + auto const src_v_block_ptr = reinterpret_cast(src_v_ptr); + + for (int head_vec_idx = threadIdx.x; head_vec_idx < VECS_PER_HEAD; head_vec_idx += vecs_per_block) + { + auto const src_k_vec_idx + = params.kv_cache_buffer.getKVLocalIdx(src_token_idx, kv_head_idx, VECS_PER_HEAD, head_vec_idx); + auto const src_v_vec_idx + = params.kv_cache_buffer.getKVLocalIdx(src_token_idx, kv_head_idx, VECS_PER_HEAD, head_vec_idx); + + k_smem[threadIdx.y * VECS_PER_HEAD + head_vec_idx] = src_k_block_ptr[src_k_vec_idx]; + v_smem[threadIdx.y * VECS_PER_HEAD + head_vec_idx] = src_v_block_ptr[src_v_vec_idx]; + } + } + __syncthreads(); + + if (sparse_token_offset < num_sparse_tokens) + { + int const global_sparse_idx = sparse_start_idx + sparse_token_offset; + int const sparse_idx_offset = kv_head_idx * total_num_sparse_kv_tokens + global_sparse_idx; + + int const src_token_idx = params.sparse_kv_indices[sparse_idx_offset]; + int const dst_token_idx = sparse_token_offset; + + if (src_token_idx != dst_token_idx) + { + void* dst_k_ptr = params.kv_cache_buffer.getKBlockPtr(batch_idx, dst_token_idx); + void* dst_v_ptr = params.kv_cache_buffer.getVBlockPtr(batch_idx, dst_token_idx); + auto const dst_k_block_ptr = reinterpret_cast(dst_k_ptr); + auto const dst_v_block_ptr = reinterpret_cast(dst_v_ptr); + + for (int head_vec_idx = threadIdx.x; head_vec_idx < VECS_PER_HEAD; head_vec_idx += vecs_per_block) + { + auto const dst_k_vec_idx + = params.kv_cache_buffer.getKVLocalIdx(dst_token_idx, kv_head_idx, VECS_PER_HEAD, head_vec_idx); + auto const dst_v_vec_idx + = params.kv_cache_buffer.getKVLocalIdx(dst_token_idx, kv_head_idx, VECS_PER_HEAD, head_vec_idx); + dst_k_block_ptr[dst_k_vec_idx] = k_smem[threadIdx.y * VECS_PER_HEAD + head_vec_idx]; + dst_v_block_ptr[dst_v_vec_idx] = v_smem[threadIdx.y * VECS_PER_HEAD + head_vec_idx]; + } + } + } + __syncthreads(); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +void kernelSparseDispatchHeadSize(QKVPreprocessingParams params, cudaStream_t stream) +{ + constexpr int VECS_PER_HEAD 
= Dh * sizeof(TCache) / 16; + constexpr int BLOCK_SIZE = 1024; + dim3 block(32, 32); // x: head vectors, y: tokens + + int smem_size = 2 * block.y * VECS_PER_HEAD * sizeof(uint4); + + // grid.x is always 1 to avoid data races + dim3 grid(1, params.kv_head_num, params.batch_size); + + updateSparseKvCacheAfterFmha<<>>(params); +} + +template +void invokeUpdateSparseKvCacheAfterFmha(QKVPreprocessingParams params, cudaStream_t stream) +{ + if (params.sparse_kv_indices == nullptr) + { + return; + } + + switch (params.size_per_head) + { + case 16: kernelSparseDispatchHeadSize<16, T, TCache, KVCacheBuffer>(params, stream); break; + case 32: kernelSparseDispatchHeadSize<32, T, TCache, KVCacheBuffer>(params, stream); break; + case 64: kernelSparseDispatchHeadSize<64, T, TCache, KVCacheBuffer>(params, stream); break; + case 128: kernelSparseDispatchHeadSize<128, T, TCache, KVCacheBuffer>(params, stream); break; + case 256: kernelSparseDispatchHeadSize<256, T, TCache, KVCacheBuffer>(params, stream); break; + default: + TLLM_CHECK_WITH_INFO( + false, "updateSparseKvCacheAfterFmha kernel doesn't support head size = %d", params.size_per_head); + break; + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + #define INSTANTIATE_ATTENTION_INPUT_PROCESSING(T, TCache, KVCacheBuffer) \ template void invokeApplyBiasRopeUpdateKVCacheDispatch( \ QKVPreprocessingParams params, cudaStream_t stream); @@ -1717,9 +1841,10 @@ void invokeUpdateCyclicKvCacheAfterFmha(QKVPreprocessingParams template void invokeApplyBiasRopeUpdateKVCacheDispatch( \ QKVPreprocessingParams params, cudaStream_t stream); \ template void invokeUpdateCyclicKvCacheAfterFmha( \ - QKVPreprocessingParams params, cudaStream_t stream); - -//////////////////////////////////////////////////////////////////////////////////////////////////// + QKVPreprocessingParams params, cudaStream_t stream); \ + template void invokeUpdateSparseKvCacheAfterFmha( \ + QKVPreprocessingParams params, cudaStream_t stream); \ + //////////////////////////////////////////////////////////////////////////////////////////////////// } // namespace kernels } // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/xqaDispatcher.cpp b/cpp/tensorrt_llm/kernels/xqaDispatcher.cpp index 20b1dbbe4e..dda7293747 100644 --- a/cpp/tensorrt_llm/kernels/xqaDispatcher.cpp +++ b/cpp/tensorrt_llm/kernels/xqaDispatcher.cpp @@ -17,6 +17,7 @@ #include "xqaDispatcher.h" #include "tensorrt_llm/common/cudaUtils.h" #include "tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplCommon.h" +#include "tensorrt_llm/kernels/sparseAttentionKernels.h" #include "tensorrt_llm/kernels/unfusedAttentionKernels.h" #include @@ -404,9 +405,15 @@ void XqaDispatcher::runImpl( // Otherwise, always enable the persistent scheduler for better performance. tllmRunnerParams.mTileScheduler = params.multi_block_mode ? TileScheduler::Static : TileScheduler::Persistent; + // The sequence lengths for K/V. + tllmRunnerParams.seqLensKvPtr = params.cross_attention ? params.encoder_input_lengths : params.sequence_lengths; + // Q buffer. tllmRunnerParams.qPtr = xqa_q_input_ptr; - // KV buffer + + // Use block sparse attention. 
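+    // Disabled by default; the paged-KV branch below enables it when params.use_sparse_attention is set.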
+ tllmRunnerParams.mUseBlockSparseAttention = false; + if constexpr (std::is_same_v) { // Paged KV @@ -416,9 +423,24 @@ void XqaDispatcher::runImpl( tllmRunnerParams.kvPageIdxPtr = reinterpret_cast(kv_cache_buffer.data); tllmRunnerParams.mMaxNumPagesPerSeqKv = kv_cache_buffer.mMaxBlocksPerSeq; tllmRunnerParams.mNumTokensPerPage = kv_cache_buffer.mTokensPerBlock; + + // Gather kv page offsets for sparse attention. + if (params.use_sparse_attention) + { + invokeGatherKvPageOffsets(reinterpret_cast(launchParams.sparse_kv_block_offsets), + launchParams.sparse_seq_lengths, reinterpret_cast(kv_cache_buffer.data), + params.sequence_lengths, params.sparse_params, batch_beam_size, num_kv_heads, + kv_cache_buffer.mTokensPerBlock, kv_cache_buffer.mMaxBlocksPerSeq, params.stream); + sync_check_cuda_error(params.stream); + tllmRunnerParams.seqLensKvPtr = launchParams.sparse_seq_lengths; + tllmRunnerParams.kvPageIdxPtr + = reinterpret_cast(launchParams.sparse_kv_block_offsets); + tllmRunnerParams.mUseBlockSparseAttention = true; + } } else { + TLLM_CHECK_WITH_INFO(!params.use_sparse_attention, "Sparse attention is not supported for KVLinearBuffer."); static_assert(std::is_same_v); // Contiguous KV tllmRunnerParams.mQkvLayout = QkvLayout::ContiguousKv; @@ -437,8 +459,6 @@ void XqaDispatcher::runImpl( tllmRunnerParams.scaleSoftmaxLog2Ptr = reinterpret_cast(launchParams.bmm1_scale_ptr + kIdxScaleSoftmaxLog2Ptr); tllmRunnerParams.oSfScalePtr = params.fp4_out_sf_scale; - // The sequence lengths for K/V. - tllmRunnerParams.seqLensKvPtr = params.cross_attention ? params.encoder_input_lengths : params.sequence_lengths; tllmRunnerParams.oPtr = params.output; tllmRunnerParams.oSfPtr = params.output_sf; diff --git a/cpp/tensorrt_llm/nanobind/thop/bindings.cpp b/cpp/tensorrt_llm/nanobind/thop/bindings.cpp index 1b86dc97fa..065b93e561 100644 --- a/cpp/tensorrt_llm/nanobind/thop/bindings.cpp +++ b/cpp/tensorrt_llm/nanobind/thop/bindings.cpp @@ -55,7 +55,7 @@ void initBindings(nb::module_& m) nb::arg("mrope_rotary_cos_sin") = std::nullopt, nb::arg("mrope_position_deltas") = std::nullopt, nb::arg("mla_tensor_params"), nb::arg("attention_chunk_size") = std::nullopt, nb::arg("softmax_stats_tensor") = std::nullopt, nb::arg("spec_decoding_bool_params"), - nb::arg("spec_decoding_tensor_params"), "Multi-head attention operation", + nb::arg("spec_decoding_tensor_params"), nb::arg("sparse_attention_params"), "Multi-head attention operation", nb::call_guard()); } } // namespace tensorrt_llm::nanobind::thop diff --git a/cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.cpp b/cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.cpp index 861b9332dd..34f97b9ad6 100644 --- a/cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.cpp @@ -572,11 +572,14 @@ size_t GPTAttentionPlugin::getWorkspaceSize(nvinfer1::PluginTensorDesc const* in = isCrossAttention() ? cross_kv_length : (useKVCache() ? inputs[getIdx(IdxEntry::CACHE_INDIR)].dims.d[2] : 0); int const max_num_tokens = mRemovePadding ? inputs[getIdx(IdxEntry::QKV_TENSOR)].dims.d[0] : max_num_seq * max_context_length; + int const max_blocks_per_sequence + = (useKVCache() && mPagedKVCache) ? 
inputs[getIdx(IdxEntry::KV_CACHE_BLOCK_OFFSETS)].dims.d[3] : 0; + size_t const context_workspace_size = getWorkspaceSizeForContext(type, max_num_seq, max_context_length, cross_kv_length, max_num_tokens); - size_t const generation_workspace_size - = getWorkspaceSizeForGeneration(type, max_num_seq, max_kv_cache_length, max_num_tokens); + size_t const generation_workspace_size = getWorkspaceSizeForGeneration( + type, max_num_seq, max_kv_cache_length, max_num_tokens, max_blocks_per_sequence); size_t attention_input_workspace_size = 0; diff --git a/cpp/tensorrt_llm/pybind/thop/bindings.cpp b/cpp/tensorrt_llm/pybind/thop/bindings.cpp index 6dbfeb2847..29181b7273 100644 --- a/cpp/tensorrt_llm/pybind/thop/bindings.cpp +++ b/cpp/tensorrt_llm/pybind/thop/bindings.cpp @@ -55,7 +55,7 @@ void initBindings(pybind11::module_& m) py::arg("mrope_rotary_cos_sin") = std::nullopt, py::arg("mrope_position_deltas") = std::nullopt, py::arg("mla_tensor_params"), py::arg("attention_chunk_size") = std::nullopt, py::arg("softmax_stats_tensor") = std::nullopt, py::arg("spec_decoding_bool_params"), - py::arg("spec_decoding_tensor_params"), "Multi-head attention operation", + py::arg("spec_decoding_tensor_params"), py::arg("sparse_attention_params"), "Multi-head attention operation", py::call_guard()); } } // namespace tensorrt_llm::pybind::thop diff --git a/cpp/tensorrt_llm/thop/attentionOp.cpp b/cpp/tensorrt_llm/thop/attentionOp.cpp index d6a64d733b..c5c04aceb7 100644 --- a/cpp/tensorrt_llm/thop/attentionOp.cpp +++ b/cpp/tensorrt_llm/thop/attentionOp.cpp @@ -19,6 +19,7 @@ #include "tensorrt_llm/common/dataType.h" #include "tensorrt_llm/kernels/gptKernels.h" #include "tensorrt_llm/kernels/mlaKernels.h" +#include "tensorrt_llm/kernels/sparseAttentionKernels.h" #include "tensorrt_llm/runtime/torchUtils.h" #include "tensorrt_llm/runtime/utils/debugUtils.h" #include "tensorrt_llm/thop/attentionOp.h" @@ -38,6 +39,7 @@ namespace trtllm::attention { using tensorrt_llm::kernels::KVBlockArray; using tensorrt_llm::kernels::MlaParams; +using tensorrt_llm::kernels::SparseAttentionParams; enum class AttentionInputType : int8_t { @@ -62,7 +64,7 @@ public: virtual ~RunnerBase() = default; virtual void prepare(AttentionOp& op) const = 0; virtual int64_t getWorkspaceSize(AttentionOp const& op, int const num_tokens, int const max_attention_window_size, - int const num_gen_tokens) const + int const num_gen_tokens, int const max_blocks_per_sequence) const = 0; // typically, we use single qkv input, but for context MLA, we use separate qkv inputs virtual void run(AttentionOp& op, bool const is_context, int32_t const seq_offset, int32_t const num_seqs, @@ -82,7 +84,9 @@ public: std::vector> mla_tensor_params, torch::optional softmax_stats_tensor, c10::ArrayRef> spec_decoding_tensor_params, - torch::optional attention_sinks) const + torch::optional attention_sinks, torch::optional sparse_kv_indices, + torch::optional sparse_kv_offsets, torch::optional sparse_attn_indices, + torch::optional sparse_attn_offsets) const = 0; }; @@ -110,12 +114,12 @@ public: } int64_t getWorkspaceSize(AttentionOp const& op, int const num_tokens, int const max_attention_window_size, - int const num_gen_tokens) const override + int const num_gen_tokens, int const max_blocks_per_sequence) const override { size_t const context_workspace_size = op.getWorkspaceSizeForContext(op.mType, max_num_requests, op.mMaxContextLength, 0, num_tokens); - size_t const generation_workspace_size - = op.getWorkspaceSizeForGeneration(op.mType, max_num_requests, 
max_attention_window_size, num_gen_tokens); + size_t const generation_workspace_size = op.getWorkspaceSizeForGeneration( + op.mType, max_num_requests, max_attention_window_size, num_gen_tokens, max_blocks_per_sequence); return std::max(context_workspace_size, generation_workspace_size); } @@ -137,7 +141,9 @@ public: std::vector> mla_tensor_params, torch::optional softmax_stats_tensor, c10::ArrayRef> spec_decoding_tensor_params, - torch::optional attention_sinks) const override + torch::optional attention_sinks, torch::optional sparse_kv_indices, + torch::optional sparse_kv_offsets, torch::optional sparse_attn_indices, + torch::optional sparse_attn_offsets) const override { auto stream = at::cuda::getCurrentCUDAStream(qkv_or_q.get_device()); T* attention_input = static_cast(qkv_or_q.slice(0, token_offset).data_ptr()); @@ -221,6 +227,22 @@ public: } } + // Prepare sparse attention parameters + if (is_context) + { + op.mRuntimeSparseAttentionParams.sparse_kv_indices + = sparse_kv_indices.has_value() ? sparse_kv_indices.value().data_ptr() : nullptr; + op.mRuntimeSparseAttentionParams.sparse_kv_offsets + = sparse_kv_offsets.has_value() ? sparse_kv_offsets.value().data_ptr() : nullptr; + } + else + { + op.mRuntimeSparseAttentionParams.sparse_attn_indices + = sparse_attn_indices.has_value() ? sparse_attn_indices.value().data_ptr() : nullptr; + op.mRuntimeSparseAttentionParams.sparse_attn_offsets + = sparse_attn_offsets.has_value() ? sparse_attn_offsets.value().data_ptr() : nullptr; + } + int const* context_lengths_ptr = context_lengths.slice(0, seq_offset).data_ptr(); int const* sequence_lengths_ptr = sequence_length.slice(0, seq_offset).data_ptr(); // Note we still need context length during generation for MMHA optimization. @@ -518,8 +540,16 @@ void attention(torch::Tensor q, std::optional k, std::optional mrope_rotary_cos_sin, std::optional mrope_position_deltas, std::vector> mla_tensor_params, std::optional attention_chunk_size, std::optional softmax_stats_tensor, std::vector spec_decoding_bool_params, - std::vector> spec_decoding_tensor_params) + std::vector> spec_decoding_tensor_params, + std::vector> sparse_attention_params) { + // Decompress sparse attention parameters + TORCH_CHECK(sparse_attention_params.size() == 4, "Expected 4 sparse attention parameters"); + torch::optional sparse_kv_indices = sparse_attention_params[0]; + torch::optional sparse_kv_offsets = sparse_attention_params[1]; + torch::optional sparse_attn_indices = sparse_attention_params[2]; + torch::optional sparse_attn_offsets = sparse_attention_params[3]; + TLLM_LOG_TRACE("Attention op starts at layer %d", layer_idx); // Use these tensors to infer if the attention is using KV cache bool const use_kv_cache = kv_cache_block_offsets.has_value() && host_kv_cache_block_offsets.has_value() @@ -633,6 +663,18 @@ void attention(torch::Tensor q, std::optional k, std::optionalmUseSpecDecoding = spec_decoding_bool_params[1]; // use_spec_decoding op->mIsSpecDecTree = spec_decoding_bool_params[2]; // is_spec_dec_tree + op->mUseSparseAttention = false; + op->mUseTllmGenSparseAttention = false; + if ((sparse_kv_indices.has_value() && sparse_kv_indices.value().numel() > 0) + || (sparse_attn_indices.has_value() && sparse_attn_indices.value().numel() > 0)) + { + op->mUseSparseAttention = true; + if (sparse_attn_indices.has_value() && sparse_attn_indices.value().numel() > 0) + { + op->mUseTllmGenSparseAttention = true; + } + } + if (is_mla_enable) { // MLA does not support NVFP4 output yet. 
@@ -718,7 +760,10 @@ void attention(torch::Tensor q, std::optional k, std::optionalgetWorkspaceSize(*op, num_tokens, max_attention_window_size, num_gen_tokens); + int32_t const max_blocks_per_sequence + = use_kv_cache && kv_cache_block_offsets.has_value() ? kv_cache_block_offsets.value().size(-1) : 0; + int64_t const workspace_size + = runner->getWorkspaceSize(*op, num_tokens, max_attention_window_size, num_gen_tokens, max_blocks_per_sequence); TLLM_LOG_TRACE("Expected workspace size is %ld bytes", workspace_size); if (workspace_size >= (16l << 30)) @@ -760,7 +805,7 @@ void attention(torch::Tensor q, std::optional k, std::optional 0) && (attn_input_type != AttentionInputType::ContextOnly)) @@ -777,7 +822,7 @@ void attention(torch::Tensor q, std::optional k, std::optional k, std::optional chunked_prefill_buffer_batch_size, std::optional q_lora_rank, std::optional kv_lora_rank, std::optional qk_nope_head_dim, std::optional qk_rope_head_dim, std::optional v_head_dim, - torch::optional mrope_rotary_cos_sin, torch::optional mrope_position_deltas, + std::optional mrope_rotary_cos_sin, std::optional mrope_position_deltas, std::vector> mla_tensor_params, std::optional attention_chunk_size, std::optional softmax_stats_tensor, std::vector spec_decoding_bool_params, - std::vector> spec_decoding_tensor_params); + std::vector> spec_decoding_tensor_params, + std::vector> sparse_attention_params); } // namespace torch_ext diff --git a/cpp/tests/unit_tests/kernels/CMakeLists.txt b/cpp/tests/unit_tests/kernels/CMakeLists.txt index 82f51afd82..2574df960e 100644 --- a/cpp/tests/unit_tests/kernels/CMakeLists.txt +++ b/cpp/tests/unit_tests/kernels/CMakeLists.txt @@ -40,6 +40,7 @@ add_gtest(smoothQuantKernelTest smoothQuant/smoothQuantKernelTest.cpp) add_gtest(stopCriteriaKernelsTest stopCriteriaKernelsTest.cpp) add_gtest(weightOnlyKernelTest weightOnly/weightOnlyKernelTest.cpp) add_gtest(mlaPreprocessTest mlaPreprocessTest.cu) +add_gtest(sparseAttentionKernelsTest sparseAttentionKernelsTest.cpp) add_gtest(cudaCoreGemmKernelTest cudaCoreGemm/cudaCoreGemmKernelTest.cpp) @@ -90,3 +91,4 @@ add_gtest(routingKernelsTest "${ROUTING_KERNEL_TEST_SRC}") add_gtest(moeLoadBalanceKernelTest moeLoadBalanceKernelTest.cpp) add_gtest(eaglePackDataTest eaglePackDataTest.cpp) +add_gtest(sparseKvCacheTest sparseKvCacheTest.cu) diff --git a/cpp/tests/unit_tests/kernels/sparseAttentionKernelsTest.cpp b/cpp/tests/unit_tests/kernels/sparseAttentionKernelsTest.cpp new file mode 100644 index 0000000000..fb6da1bc22 --- /dev/null +++ b/cpp/tests/unit_tests/kernels/sparseAttentionKernelsTest.cpp @@ -0,0 +1,191 @@ +#include + +#include "tensorrt_llm/kernels/sparseAttentionKernels.h" +#include "tensorrt_llm/runtime/bufferManager.h" +#include "tensorrt_llm/runtime/cudaStream.h" + +#include +#include + +using namespace tensorrt_llm::kernels; +using namespace tensorrt_llm::runtime; + +namespace +{ +class sparseAttentionKernelsTest : public ::testing::Test +{ +public: + void SetUp() override + { + mStream = std::make_shared(); + mBufferManager = std::make_shared(mStream); + } + + void TearDown() override {} + +protected: + std::shared_ptr mStream; + std::shared_ptr mBufferManager; +}; + +TEST_F(sparseAttentionKernelsTest, GatherKvPageOffsetsKernelTest) +{ + // Test parameters + constexpr int max_batch_size = 4; + constexpr int batch_size = 2; + constexpr int num_head_kv = 4; + constexpr int max_num_pages_per_seq = 8; + constexpr int tokens_per_page = 64; + constexpr int total_sparse_pages = max_batch_size * max_num_pages_per_seq; // Total 
sparse pages across all batches + + // Create input buffers + auto kv_page_offsets + = mBufferManager->gpu(ITensor::makeShape({batch_size, 2, max_num_pages_per_seq}), nvinfer1::DataType::kINT32); + auto seq_lengths = mBufferManager->gpu(ITensor::makeShape({batch_size}), nvinfer1::DataType::kINT32); + auto sparse_indices + = mBufferManager->gpu(ITensor::makeShape({total_sparse_pages, num_head_kv}), nvinfer1::DataType::kINT32); + auto sparse_indices_offsets = mBufferManager->gpu(ITensor::makeShape({batch_size + 1}), nvinfer1::DataType::kINT32); + + // Create output buffers + auto output_kv_page_offsets = mBufferManager->gpu( + ITensor::makeShape({num_head_kv, batch_size, 2, max_num_pages_per_seq}), nvinfer1::DataType::kINT32); + auto output_seq_lengths + = mBufferManager->gpu(ITensor::makeShape({num_head_kv, batch_size}), nvinfer1::DataType::kINT32); + + // Create pinned host buffers for data initialization + auto kv_page_offsets_host = mBufferManager->pinned( + ITensor::makeShape({batch_size, 2, max_num_pages_per_seq}), nvinfer1::DataType::kINT32); + auto seq_lengths_host = mBufferManager->pinned(ITensor::makeShape({batch_size}), nvinfer1::DataType::kINT32); + auto sparse_indices_host + = mBufferManager->pinned(ITensor::makeShape({total_sparse_pages, num_head_kv}), nvinfer1::DataType::kINT32); + auto sparse_indices_offsets_host + = mBufferManager->pinned(ITensor::makeShape({batch_size + 1}), nvinfer1::DataType::kINT32); + + // Initialize test data + auto kv_page_offsets_ptr = bufferCast(*kv_page_offsets_host); + auto seq_lengths_ptr = bufferCast(*seq_lengths_host); + auto sparse_indices_ptr = bufferCast(*sparse_indices_host); + auto sparse_indices_offsets_ptr = bufferCast(*sparse_indices_offsets_host); + + // Initialize KV page offsets with test data + for (int b = 0; b < batch_size; ++b) + { + for (int d = 0; d < 2; ++d) + { + for (int p = 0; p < max_num_pages_per_seq; ++p) + { + int offset = b * 2 * max_num_pages_per_seq + d * max_num_pages_per_seq + p; + kv_page_offsets_ptr[offset] = 1000 + b * 100 + d * 10 + p; + } + } + } + + // Initialize sequence lengths + seq_lengths_ptr[0] = 2 * tokens_per_page + 18; // 3 pages for batch 0 + seq_lengths_ptr[1] = 3 * tokens_per_page + 3; // 4 pages for batch 1 + + // Initialize sparse indices with different patterns for different heads + // Shape: {total_sparse_pages, num_head_kv} + // Each head can have its own sparse pattern + int num_sparse_pages = 5; + int sparse_page_indices[5][4] = {{1, 0, 0, 1}, {2, 1, 1, -1}, {-1, 2, -1, -1}, {0, 1, 2, 3}, {3, 3, 3, -1}}; + for (int page = 0; page < num_sparse_pages; ++page) + { + for (int head = 0; head < num_head_kv; ++head) + { + int idx = head * num_sparse_pages + page; + sparse_indices_ptr[idx] = sparse_page_indices[page][head]; + } + } + + // Initialize sparse indices offsets + sparse_indices_offsets_ptr[0] = 0; // Start of batch 0 + sparse_indices_offsets_ptr[1] = 3; // Start of batch 1 (3 sparse pages for batch 0) + sparse_indices_offsets_ptr[2] = 5; // End (3 sparse pages for batch 1) + + // Copy data to GPU + mBufferManager->copy(*kv_page_offsets_host, *kv_page_offsets); + mBufferManager->copy(*seq_lengths_host, *seq_lengths); + mBufferManager->copy(*sparse_indices_host, *sparse_indices); + mBufferManager->copy(*sparse_indices_offsets_host, *sparse_indices_offsets); + + SparseAttentionParams sparse_params; + sparse_params.sparse_attn_indices = bufferCast(*sparse_indices); + sparse_params.sparse_attn_offsets = bufferCast(*sparse_indices_offsets); + + // Launch the kernel + 
invokeGatherKvPageOffsets(bufferCast(*output_kv_page_offsets), bufferCast(*output_seq_lengths), + bufferCast(*kv_page_offsets), bufferCast(*seq_lengths), sparse_params, batch_size, + num_head_kv, tokens_per_page, max_num_pages_per_seq, mStream->get()); + + // Wait for completion + mStream->synchronize(); + + // Copy results back to host for verification + auto output_kv_page_offsets_host = mBufferManager->pinned( + ITensor::makeShape({num_head_kv, batch_size, 2, max_num_pages_per_seq}), nvinfer1::DataType::kINT32); + auto output_seq_lengths_host + = mBufferManager->pinned(ITensor::makeShape({num_head_kv, batch_size}), nvinfer1::DataType::kINT32); + + mBufferManager->copy(*output_kv_page_offsets, *output_kv_page_offsets_host); + mBufferManager->copy(*output_seq_lengths, *output_seq_lengths_host); + + // Wait for completion + mStream->synchronize(); + + auto output_kv_offsets_ptr = bufferCast(*output_kv_page_offsets_host); + auto output_seq_len_ptr = bufferCast(*output_seq_lengths_host); + + // Verify sequence lengths for each head and batch + int expected_seq_lens[4][2] = { + {tokens_per_page + 18, tokens_per_page + 3}, // Head 0: batch 0 has 2 pages, batch 1 has 0 pages + {2 * tokens_per_page + 18, tokens_per_page + 3}, // Head 1: batch 0 has 3 pages, batch 1 has 0 pages + {2 * tokens_per_page, tokens_per_page + 3}, // Head 2: batch 0 has 2 pages, batch 1 has 0 pages + {tokens_per_page, 3} // Head 3: batch 0 has 2 pages, batch 1 has 0 pages + }; + + for (int h = 0; h < num_head_kv; ++h) + { + for (int b = 0; b < batch_size; ++b) + { + int seq_len_idx = h * batch_size + b; + EXPECT_EQ(output_seq_len_ptr[seq_len_idx], expected_seq_lens[h][b]) + << "Sequence length mismatch at head=" << h << ", batch=" << b; + } + } + + // Verify gathered KV page offsets + for (int h = 0; h < num_head_kv; ++h) + { + for (int b = 0; b < batch_size; ++b) + { + int num_sparse_pages_batch = sparse_indices_offsets_ptr[b + 1] - sparse_indices_offsets_ptr[b]; + for (int d = 0; d < 2; ++d) + { + for (int p = 0; p < num_sparse_pages_batch; ++p) + { + // Calculate expected value (from the sparse index) + int sparse_idx_global = sparse_indices_offsets_ptr[b] + p; + int src_page_idx + = sparse_indices_ptr[h * sparse_indices_offsets_ptr[batch_size] + sparse_idx_global]; + + if (src_page_idx == -1) + { + continue; // Skip invalid indices + } + + // Calculate output offset + size_t output_offset = h * batch_size * 2 * max_num_pages_per_seq + b * 2 * max_num_pages_per_seq + + d * max_num_pages_per_seq + p; + + int expected_value = 1000 + b * 100 + d * 10 + src_page_idx; + + EXPECT_EQ(output_kv_offsets_ptr[output_offset], expected_value) + << "Mismatch at head=" << h << ", batch=" << b << ", dim=" << d << ", page=" << p; + } + } + } + } +} + +} // namespace diff --git a/cpp/tests/unit_tests/kernels/sparseKvCacheTest.cu b/cpp/tests/unit_tests/kernels/sparseKvCacheTest.cu new file mode 100644 index 0000000000..f1e27f5028 --- /dev/null +++ b/cpp/tests/unit_tests/kernels/sparseKvCacheTest.cu @@ -0,0 +1,521 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include + +#include "tensorrt_llm/common/assert.h" +#include "tensorrt_llm/common/cudaUtils.h" +#include "tensorrt_llm/kernels/kvCacheUtils.h" +#include "tensorrt_llm/kernels/unfusedAttentionKernels.h" + +using namespace tensorrt_llm::kernels; + +class SparseKvCacheTest : public ::testing::Test +{ +protected: + void SetUp() override + { + mStream = nullptr; + TLLM_CUDA_CHECK(cudaStreamCreate(&mStream)); + + // Test parameters + mBatchSize = 2; + mNumKvHeads = 4; + mHeadSize = 128; + mMaxSeqLen = 512; + mTokensPerBlock = 64; + mMaxBlocksPerSeq = 8; + + // Allocate test data + setupTestData(); + } + + void TearDown() override + { + cleanup(); + if (mStream) + { + TLLM_CUDA_CHECK(cudaStreamDestroy(mStream)); + } + } + + void setupTestData() + { + // Allocate device memory for sparse KV indices and offsets + size_t sparse_indices_size = 20 * mNumKvHeads * sizeof(int); // 20 sparse tokens max + size_t sparse_offsets_size = (mBatchSize + 1) * sizeof(int); + + TLLM_CUDA_CHECK(cudaMalloc(&mSparseKvIndicesDevice, sparse_indices_size)); + TLLM_CUDA_CHECK(cudaMalloc(&mSparseKvOffsetsDevice, sparse_offsets_size)); + TLLM_CUDA_CHECK(cudaMalloc(&mSeqLensDevice, mBatchSize * sizeof(int))); + TLLM_CUDA_CHECK(cudaMalloc(&mCacheSeqLensDevice, mBatchSize * sizeof(int))); + + // Create sparse indices in the correct format: [sparse_token_idx][head_idx] + // Total sparse tokens: 5 (batch 0) + 3 (batch 1) = 8 + std::vector sparseKvIndicesHost; + + // Batch 0: 5 sparse tokens per head + std::vector> batch0_indices = { + {1, 2, 3, 4, 5}, // head 0 + {3, 4, 5, 6, 8}, // head 1 + {0, 1, 3, 5, 8}, // head 2 + {1, 3, 5, 10, 11} // head 3 + }; + + // Batch 1: 3 sparse tokens per head + std::vector> batch1_indices = { + {1, 4, 7}, // head 0 + {0, 2, 3}, // head 1 + {1, 2, 7}, // head 2 + {1, 3, 7} // head 3 + }; + + // [num_kv_heads, num_sparse_kv_tokens] + for (int head = 0; head < mNumKvHeads; ++head) + { + for (size_t token = 0; token < batch0_indices[head].size(); ++token) + { + sparseKvIndicesHost.push_back(batch0_indices[head][token]); + } + for (size_t token = 0; token < batch1_indices[head].size(); ++token) + { + sparseKvIndicesHost.push_back(batch1_indices[head][token]); + } + } + + std::vector sparseKvOffsetsHost = {0, 5, 8}; // Batch 0: 5 tokens, Batch 1: 3 tokens + std::vector seqLensHost = {12, 8}; // Original sequence lengths + std::vector cacheSeqLensHost = {12, 8}; // Cache sequence lengths + + TLLM_CUDA_CHECK(cudaMemcpy(mSparseKvIndicesDevice, sparseKvIndicesHost.data(), + sparseKvIndicesHost.size() * sizeof(int), cudaMemcpyHostToDevice)); + TLLM_CUDA_CHECK(cudaMemcpy(mSparseKvOffsetsDevice, sparseKvOffsetsHost.data(), + sparseKvOffsetsHost.size() * sizeof(int), cudaMemcpyHostToDevice)); + TLLM_CUDA_CHECK( + cudaMemcpy(mSeqLensDevice, seqLensHost.data(), seqLensHost.size() * sizeof(int), cudaMemcpyHostToDevice)); + TLLM_CUDA_CHECK(cudaMemcpy(mCacheSeqLensDevice, cacheSeqLensHost.data(), cacheSeqLensHost.size() * sizeof(int), + cudaMemcpyHostToDevice)); + TLLM_CUDA_CHECK(cudaDeviceSynchronize()); + // Setup KV cache buffer using KVBlockArray + setupKvCacheBuffer(); + } + + void setupKvCacheBuffer() + { + // Calculate memory requirements + auto const elemSize = sizeof(half); + auto const sizePerToken = mNumKvHeads * mHeadSize * elemSize; + auto const bytesPerBlock = mTokensPerBlock * sizePerToken; + auto const totalBlocks = mBatchSize * mMaxBlocksPerSeq; + auto 
const poolSize = totalBlocks * bytesPerBlock * 2; // K and V + + // Allocate primary pool + TLLM_CUDA_CHECK(cudaMalloc(&mKvCachePool, poolSize)); + TLLM_CUDA_CHECK(cudaMemset(mKvCachePool, 0, poolSize)); + + // Allocate block offsets + size_t offsetsSize = mBatchSize * mMaxBlocksPerSeq * 2 * sizeof(KVCacheIndex); + TLLM_CUDA_CHECK(cudaMalloc(&mBlockOffsetsDevice, offsetsSize)); + + // Initialize block offsets (simple linear mapping for test) + std::vector blockOffsetsHost; + blockOffsetsHost.reserve(mBatchSize * mMaxBlocksPerSeq * 2); + + for (int batch = 0; batch < mBatchSize; ++batch) + { + for (int block = 0; block < mMaxBlocksPerSeq; ++block) + { + // K cache block offset + int kBlockIdx = batch * mMaxBlocksPerSeq * 2 + block; + blockOffsetsHost.emplace_back(kBlockIdx, false); + } + for (int block = 0; block < mMaxBlocksPerSeq; ++block) + { + // V cache block offset + int vBlockIdx = batch * mMaxBlocksPerSeq * 2 + mMaxBlocksPerSeq + block; + blockOffsetsHost.emplace_back(vBlockIdx, false); + } + } + + TLLM_CUDA_CHECK(cudaMemcpy(mBlockOffsetsDevice, blockOffsetsHost.data(), + blockOffsetsHost.size() * sizeof(KVCacheIndex), cudaMemcpyHostToDevice)); + + TLLM_CUDA_CHECK(cudaDeviceSynchronize()); + // Create KVBlockArray with correct parameter order: + // (batchSize, maxBlocksPerSeq, tokensPerBlock, bytesPerToken, + // maxAttentionWindow, maxAttentionWindowAllLayer, sinkTokenLen, canUseOneMoreBlock, + // primaryPoolPtr, secondaryPoolPtr, data) + mKvCacheBuffer = KVBlockArray(mBatchSize, mMaxBlocksPerSeq, mTokensPerBlock, sizePerToken, mMaxSeqLen, + mMaxSeqLen, 0, false, mKvCachePool, nullptr, mBlockOffsetsDevice); + } + + void cleanup() + { + if (mSparseKvIndicesDevice) + cudaFree(mSparseKvIndicesDevice); + if (mSparseKvOffsetsDevice) + cudaFree(mSparseKvOffsetsDevice); + if (mSeqLensDevice) + cudaFree(mSeqLensDevice); + if (mCacheSeqLensDevice) + cudaFree(mCacheSeqLensDevice); + if (mKvCachePool) + cudaFree(mKvCachePool); + if (mBlockOffsetsDevice) + cudaFree(mBlockOffsetsDevice); + if (mQkvInputDevice) + cudaFree(mQkvInputDevice); + } + + // Test parameters + int mBatchSize; + int mNumKvHeads; + int mHeadSize; + int mMaxSeqLen; + int mTokensPerBlock; + int mMaxBlocksPerSeq; + + // Device memory + int* mSparseKvIndicesDevice = nullptr; + int* mSparseKvOffsetsDevice = nullptr; + int* mSeqLensDevice = nullptr; + int* mCacheSeqLensDevice = nullptr; + void* mKvCachePool = nullptr; + KVCacheIndex* mBlockOffsetsDevice = nullptr; + half* mQkvInputDevice = nullptr; + + KVBlockArray mKvCacheBuffer; + cudaStream_t mStream; + + // Helper functions for verification + bool verifySparseKvCacheMapping(std::vector const& originalKvCache); + void performHostSparseMapping(std::vector const& originalKvCache, std::vector& expectedKvCache); + void extractKvCacheFromGpu(std::vector& kvCacheHost); + void initializeKvCacheWithPattern(); +}; + +TEST_F(SparseKvCacheTest, UpdateSparseKvCacheAfterFmha) +{ + // Allocate dummy QKV input (normally this would come from the attention computation) + size_t qkvInputSize = mBatchSize * mMaxSeqLen * 3 * mNumKvHeads * mHeadSize * sizeof(half); + TLLM_CUDA_CHECK(cudaMalloc(&mQkvInputDevice, qkvInputSize)); + + // Initialize with test pattern + std::vector qkvInputHost(qkvInputSize / sizeof(half)); + for (size_t i = 0; i < qkvInputHost.size(); ++i) + { + qkvInputHost[i] = half(float(i % 1000) / 100.0f); // Simple test pattern + } + TLLM_CUDA_CHECK(cudaMemcpy(mQkvInputDevice, qkvInputHost.data(), qkvInputSize, cudaMemcpyHostToDevice)); + + // Initialize KV cache with initial 
data pattern for testing + initializeKvCacheWithPattern(); + + // Extract the original KV cache data before kernel execution for verification + size_t totalKvElements = mBatchSize * mMaxSeqLen * mNumKvHeads * mHeadSize * 2; // K and V + std::vector originalKvCache(totalKvElements); + extractKvCacheFromGpu(originalKvCache); + + // Setup QKVPreprocessingParams + QKVPreprocessingParams params; + memset(¶ms, 0, sizeof(params)); + + params.qkv_input = mQkvInputDevice; + params.kv_cache_buffer = mKvCacheBuffer; + params.sparse_kv_indices = mSparseKvIndicesDevice; + params.sparse_kv_offsets = mSparseKvOffsetsDevice; + params.seq_lens = mSeqLensDevice; + params.cache_seq_lens = mCacheSeqLensDevice; + + params.batch_size = mBatchSize; + params.head_num = mNumKvHeads; // For Q heads, assuming same as KV heads for this test + params.kv_head_num = mNumKvHeads; + params.size_per_head = mHeadSize; + params.cache_type = KvCacheDataType::BASE; + params.rotary_embedding_dim = 0; // No rotary embedding for this test + + params.setCommonParameters(); + + // Verify sparse indices and offsets on host + std::vector hostSparseIndices(8 * mNumKvHeads); + TLLM_CUDA_CHECK(cudaMemcpy(hostSparseIndices.data(), mSparseKvIndicesDevice, hostSparseIndices.size() * sizeof(int), + cudaMemcpyDeviceToHost)); + + std::vector hostSparseOffsets(mBatchSize + 1); + TLLM_CUDA_CHECK(cudaMemcpy(hostSparseOffsets.data(), mSparseKvOffsetsDevice, hostSparseOffsets.size() * sizeof(int), + cudaMemcpyDeviceToHost)); + TLLM_CUDA_CHECK(cudaDeviceSynchronize()); + cudaError_t pre_kernel_error = cudaGetLastError(); + if (pre_kernel_error != cudaSuccess) + { + printf("Debug: CUDA error before kernel call: %s\n", cudaGetErrorString(pre_kernel_error)); + } + + invokeUpdateSparseKvCacheAfterFmha(params, mStream); + + TLLM_CUDA_CHECK(cudaStreamSynchronize(mStream)); + cudaError_t post_kernel_error = cudaGetLastError(); + if (post_kernel_error != cudaSuccess) + { + printf("Debug: CUDA error after kernel call: %s\n", cudaGetErrorString(post_kernel_error)); + } + else + { + printf("Debug: Kernel call completed, no immediate CUDA errors\n"); + } + + // Verification: Compare GPU result with CPU reference implementation + EXPECT_TRUE(verifySparseKvCacheMapping(originalKvCache)); +} + +// Implementation of verification functions +bool SparseKvCacheTest::verifySparseKvCacheMapping(std::vector const& originalKvCache) +{ + // Perform host-side sparse mapping to get expected result + size_t totalKvElements = originalKvCache.size(); + std::vector expectedKvCache{originalKvCache}; + performHostSparseMapping(originalKvCache, expectedKvCache); + + // Extract actual result from GPU after sparse kernel execution + std::vector actualKvCache(totalKvElements); + extractKvCacheFromGpu(actualKvCache); + + // Compare results with tolerance - only for valid sparse tokens + float const tolerance = 1e-5f; + bool passed = true; + int errorCount = 0; + int const maxErrorsToShow = 10; + size_t totalValidElements = 0; + + // Only compare the sparse tokens that should have been reorganized + std::vector sparseTokenCounts = {5, 3}; // Batch 0: 5 tokens, Batch 1: 3 tokens + + for (int batch = 0; batch < mBatchSize; ++batch) + { + int numSparseTokens = sparseTokenCounts[batch]; + + for (int kv = 0; kv < 2; ++kv) // K and V + { + for (int token = 0; token < numSparseTokens; ++token) + { + for (int head = 0; head < mNumKvHeads; ++head) + { + for (int dim = 0; dim < mHeadSize; ++dim) + { + // Calculate index in flat array + // Layout: [batch][kv][token][head][dim] + size_t idx = 
batch * (2 * mMaxSeqLen * mNumKvHeads * mHeadSize) + + kv * (mMaxSeqLen * mNumKvHeads * mHeadSize) + token * (mNumKvHeads * mHeadSize) + + head * mHeadSize + dim; + + if (idx < totalKvElements) + { + float expected = float(expectedKvCache[idx]); + float actual = float(actualKvCache[idx]); + float diff = std::abs(expected - actual); + + if (diff > tolerance) + { + if (errorCount < maxErrorsToShow) + { + printf( + "Mismatch at batch=%d, kv=%d, token=%d, head=%d, dim=%d: expected %.6f, got " + "%.6f, diff %.6f\n", + batch, kv, token, head, dim, expected, actual, diff); + } + errorCount++; + passed = false; + } + totalValidElements++; + } + } + } + } + } + } + + if (errorCount > 0) + { + printf("Total errors: %d out of %zu valid sparse token elements\n", errorCount, totalValidElements); + } + else + { + printf("Verification passed: all %zu valid sparse token elements match within tolerance %.2e\n", + totalValidElements, tolerance); + } + + return passed; +} + +void SparseKvCacheTest::performHostSparseMapping( + std::vector const& originalKvCache, std::vector& expectedKvCache) +{ + // Host-side reference implementation of sparse KV cache mapping + // This is a naive but correct implementation for verification + + // Read sparse indices from GPU memory to get the actual data being used + std::vector hostSparseIndices(8 * mNumKvHeads); + TLLM_CUDA_CHECK(cudaMemcpy(hostSparseIndices.data(), mSparseKvIndicesDevice, hostSparseIndices.size() * sizeof(int), + cudaMemcpyDeviceToHost)); + + std::vector hostSparseOffsets(mBatchSize + 1); + TLLM_CUDA_CHECK(cudaMemcpy(hostSparseOffsets.data(), mSparseKvOffsetsDevice, hostSparseOffsets.size() * sizeof(int), + cudaMemcpyDeviceToHost)); + TLLM_CUDA_CHECK(cudaDeviceSynchronize()); + + // Process each batch + for (int batch = 0; batch < mBatchSize; ++batch) + { + int const sparse_start_idx = hostSparseOffsets[batch]; + int const sparse_end_idx = hostSparseOffsets[batch + 1]; + int const num_sparse_tokens = sparse_end_idx - sparse_start_idx; + + // Process each head + for (int head = 0; head < mNumKvHeads; ++head) + { + // Process both K and V cache + for (int kv = 0; kv < 2; ++kv) // 0 = K, 1 = V + { + // For each sparse token in this batch + for (int sparseTokenOffset = 0; sparseTokenOffset < num_sparse_tokens; ++sparseTokenOffset) + { + // Calculate index in new format: [num_kv_heads, num_sparse_kv_tokens] + int const global_sparse_idx = sparse_start_idx + sparseTokenOffset; + int const sparse_idx_offset = head * 8 + global_sparse_idx; // 8 is total sparse tokens + int const originalTokenIdx = hostSparseIndices[sparse_idx_offset]; + int const continuousTokenIdx = sparseTokenOffset; + + // Copy from original position to continuous position + for (int dim = 0; dim < mHeadSize; ++dim) + { + // Calculate indices in the flat array + // Layout: [batch][kv][token][head][dim] + size_t srcIdx = batch * (2 * mMaxSeqLen * mNumKvHeads * mHeadSize) + + kv * (mMaxSeqLen * mNumKvHeads * mHeadSize) + originalTokenIdx * (mNumKvHeads * mHeadSize) + + head * mHeadSize + dim; + + size_t dstIdx = batch * (2 * mMaxSeqLen * mNumKvHeads * mHeadSize) + + kv * (mMaxSeqLen * mNumKvHeads * mHeadSize) + + continuousTokenIdx * (mNumKvHeads * mHeadSize) + head * mHeadSize + dim; + + if (srcIdx < originalKvCache.size() && dstIdx < expectedKvCache.size()) + { + expectedKvCache[dstIdx] = originalKvCache[srcIdx]; + } + } + } + } + } + } +} + +void SparseKvCacheTest::extractKvCacheFromGpu(std::vector& kvCacheHost) +{ + // Extract KV cache data from GPU KVBlockArray structure + // This is a 
simplified extraction for testing purposes + + // Calculate total size needed + size_t totalElements = mBatchSize * mMaxSeqLen * mNumKvHeads * mHeadSize * 2; + kvCacheHost.resize(totalElements); + + // For testing, we'll use a simplified approach to read back the cache + // In a real implementation, this would need to handle the block structure properly + + // Calculate pool size + auto const elemSize = sizeof(half); + auto const sizePerToken = mNumKvHeads * mHeadSize * elemSize; + auto const bytesPerBlock = mTokensPerBlock * sizePerToken; + auto const totalBlocks = mBatchSize * mMaxBlocksPerSeq; + auto const poolSize = totalBlocks * bytesPerBlock * 2; // K and V + + // Create temporary buffer to read entire pool + std::vector poolData(poolSize / sizeof(half)); + TLLM_CUDA_CHECK(cudaMemcpy(poolData.data(), mKvCachePool, poolSize, cudaMemcpyDeviceToHost)); + TLLM_CUDA_CHECK(cudaDeviceSynchronize()); + // Reorganize from block structure to linear structure for comparison + // This is a simplified mapping - in reality you'd need to handle block indexing + for (int batch = 0; batch < mBatchSize; ++batch) + { + for (int token = 0; token < mMaxSeqLen; ++token) + { + for (int head = 0; head < mNumKvHeads; ++head) + { + for (int dim = 0; dim < mHeadSize; ++dim) + { + for (int kv = 0; kv < 2; ++kv) + { + // Calculate block coordinates + int blockIdx = token / mTokensPerBlock; + int tokenInBlock = token % mTokensPerBlock; + + // Calculate source index in pool (simplified) + // The layout of a block in the pool is [mTokensPerBlock, mNumKvHeads, mHeadSize] + size_t block_base_pool_idx + = (size_t) (batch * mMaxBlocksPerSeq * 2 + kv * mMaxBlocksPerSeq + blockIdx) + * mTokensPerBlock * mNumKvHeads * mHeadSize; + + size_t inner_block_pool_idx + = (size_t) head * mTokensPerBlock * mHeadSize + (size_t) tokenInBlock * mHeadSize + dim; + + size_t poolIdx = block_base_pool_idx + inner_block_pool_idx; + + // Calculate destination index in linear layout + size_t linearIdx = (size_t) batch * (2 * mMaxSeqLen * mNumKvHeads * mHeadSize) + + (size_t) kv * (mMaxSeqLen * mNumKvHeads * mHeadSize) + + (size_t) token * (mNumKvHeads * mHeadSize) + (size_t) head * mHeadSize + dim; + + if (poolIdx < poolData.size() && linearIdx < kvCacheHost.size()) + { + kvCacheHost[linearIdx] = poolData[poolIdx]; + } + } + } + } + } + } +} + +void SparseKvCacheTest::initializeKvCacheWithPattern() +{ + + // Calculate pool size + auto const elemSize = sizeof(half); + auto const sizePerToken = mNumKvHeads * mHeadSize * elemSize; + auto const bytesPerBlock = mTokensPerBlock * sizePerToken; + auto const totalBlocks = mBatchSize * mMaxBlocksPerSeq; + auto const poolSize = totalBlocks * bytesPerBlock * 2; // K and V + + // Create host data with recognizable pattern + std::vector poolData(poolSize / sizeof(half)); + for (size_t i = 0; i < poolData.size(); ++i) + { + poolData[i] = half(float(i) / 1000.0f); + } + + // Copy to GPU + TLLM_CUDA_CHECK(cudaMemcpy(mKvCachePool, poolData.data(), poolSize, cudaMemcpyHostToDevice)); + TLLM_CUDA_CHECK(cudaDeviceSynchronize()); +} + +// Main function for standalone compilation +int main(int argc, char** argv) +{ + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/examples/llm-api/llm_sparse_attention.py b/examples/llm-api/llm_sparse_attention.py new file mode 100644 index 0000000000..948e06d77e --- /dev/null +++ b/examples/llm-api/llm_sparse_attention.py @@ -0,0 +1,155 @@ +### :title Sparse Attention +### :order 5 +### :section Customization +""" +This example demonstrates how 
to use sparse attention with TensorRT-LLM. + +Supported sparse attention algorithms: +- RocketKV + +Usage: +```bash +python llm_sparse_attention.py --algo RocketKV --attention_backend TRTLLM --window_size 32 --kernel_size 63 --prompt_budget 2048 +``` +""" +import argparse +import json + +from tensorrt_llm import LLM, SamplingParams +from tensorrt_llm.llmapi import KvCacheConfig, RocketSparseAttentionConfig + + +def read_input(input_file): + results = [] + with open(input_file, 'r') as f: + for line in f: + ret = json.loads(line) + results.append(ret) + return results + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument( + '--model_path', + type=str, + default= + "/home/scratch.trt_llm_data/llm-models/llama-3.1-model/Llama-3.1-8B-Instruct" + ) + parser.add_argument( + '--input_file', + type=str, + default="tests/unittest/_torch/multi_gpu/test_star_attention_input.jsonl" + ) + # Build config + parser.add_argument('--algo', + type=str, + default='ROCKETKV', + choices=['ROCKETKV']) + parser.add_argument('--attention_backend', + type=str, + default='TRTLLM', + choices=['VANILLA', 'TRTLLM']) + parser.add_argument('--window_size', + type=int, + default=32, + help="The window size for RocketKV.") + parser.add_argument('--kernel_size', + type=int, + default=63, + help="The kernel size for RocketKV.") + parser.add_argument('--prompt_budget', + type=int, + default=2048, + help="The prompt budget for RocketKV.") + parser.add_argument("--max_seq_len", + type=int, + default=8192, + help="The maximum sequence length.") + parser.add_argument("--max_batch_size", + type=int, + default=256, + help="The maximum batch size.") + parser.add_argument("--max_new_tokens", + type=int, + default=128, + help="The maximum new tokens.") + parser.add_argument( + "--max_num_tokens", + type=int, + default=8192, + help= + "The maximum total tokens (context + generation) across all sequences in a batch." 
+ ) + parser.add_argument('--tensor_parallel_size', type=int, default=1) + + # KV cache + parser.add_argument('--kv_cache_dtype', type=str, default='auto') + parser.add_argument("--kv_cache_fraction", type=float, default=None) + parser.add_argument('--num_samples', type=int, default=10) + + args = parser.parse_args() + return args + + +def run_RocketKV(args): + data = read_input(args.input_file) + num_samples = args.num_samples if args.num_samples is not None else len( + data) + data = data[:num_samples] + + kv_cache_config = KvCacheConfig( + enable_block_reuse= + False, # sparse attention does not support kv cache reuse now + free_gpu_memory_fraction=args.kv_cache_fraction, + dtype=args.kv_cache_dtype, + ) + sparse_attention_config = RocketSparseAttentionConfig( + window_size=args.window_size, + kernel_size=args.kernel_size, + prompt_budget=args.prompt_budget, + ) + + llm = LLM( + model=args.model_path, + backend='pytorch', + kv_cache_config=kv_cache_config, + attn_backend=args.attention_backend, + sparse_attention_config=sparse_attention_config, + max_batch_size=args.max_batch_size, + max_seq_len=args.max_seq_len, + max_num_tokens=args.max_num_tokens, + tensor_parallel_size=args.tensor_parallel_size, + cuda_graph_config= + None, # sparse attention does not support cuda graph now + ) + + prompts = [] + reference = [] + for sample in data: + prompts.append( + {'prompt': sample['input_context'] + sample['input_query']}) + reference.append(sample['outputs']) + + sampling_params = SamplingParams(add_special_tokens=False, + max_tokens=args.max_new_tokens, + temperature=0.8, + top_p=0.95) + + outputs = llm.generate(prompts, sampling_params) + for idx, output in enumerate(outputs): + print( + f'Generated text: {output.outputs[0].text!r}, ref: {reference[idx]}' + ) + + +def main(): + args = parse_arguments() + if args.algo == 'ROCKETKV': + run_RocketKV(args) + else: + raise ValueError(f"Invalid algorithm: {args.algo}") + + +if __name__ == "__main__": + main() diff --git a/examples/longbench/README.md b/examples/longbench/README.md new file mode 100644 index 0000000000..4f12c3c672 --- /dev/null +++ b/examples/longbench/README.md @@ -0,0 +1,170 @@ +# LongBench Evaluation with TensorRT-LLM and Sparse Attention + +This directory contains evaluation scripts for both LongBench v1 and LongBench v2 datasets using TensorRT-LLM backend. + +## Environment Setup + +### 1. Clone LongBench Repository + +First, clone the LongBench repository which contains the datasets and evaluation utilities: + +```bash +git clone https://github.com/THUDM/LongBench.git +``` + +### 2. Install Requirements + +Install the required dependencies: + +```bash +pip install -r requirements.txt +``` + +### 3. Directory Structure + +After cloning, your directory structure should look like: + +```text +sparse_attention/ +├── eval_longbench_v1.py # LongBench v1 evaluation script +├── eval_longbench_v2.py # LongBench v2 evaluation script +├── README.md # This file +└── LongBench/ # Cloned LongBench repository + ├── LongBench/ # LongBench v1 data and configs + │ ├── config/ + │ └── ... + ├── config/ # LongBench v2 configs + ├── ... + └── requirements.txt +``` + +## Scripts Overview + +### 1. eval_longbench_v1.py + +This script evaluates models on the **LongBench v1** dataset, which includes multiple specific tasks like narrativeqa, qasper, multifieldqa, etc. 
Key features: + +- **Dataset**: LongBench v1 with task-specific evaluation +- **Tasks**: Support for 20+ different long-context tasks +- **Prompts**: Task-specific prompts from LongBench v1 configuration +- **Metrics**: Task-specific metrics (F1, ROUGE, classification scores, etc.) +- **Output**: Task-level results with comprehensive summary statistics + +### 2. eval_longbench_v2.py + +This script evaluates models on the **LongBench v2** dataset, which is a standardized multiple-choice format. Key features: + +- **Dataset**: LongBench v2 with unified multiple-choice format +- **Format**: All questions are A/B/C/D multiple choice +- **Context Length**: 8K to 2M words (majority under 128K) +- **Difficulty**: Easy/Hard categorization +- **Length**: Short/Medium/Long categorization +- **Domains**: Various domains (single-doc QA, multi-doc QA, code, etc.) +- **CoT Support**: Chain-of-Thought reasoning support +- **Metrics**: Accuracy with breakdowns by difficulty, length, and domain + +## Usage Examples + +### LongBench v1 Evaluation + +#### Basic Usage (Standard Attention) +```bash +python eval_longbench_v1.py \ + --model_path "/path/to/your/model" \ + --longbench_path ./LongBench \ + --output_dir results/v1_vanilla \ + --attention_backend VANILLA \ + --backend pytorch +``` + +#### Specific tasks With Sparse Attention (RocketKV) +```bash +python eval_longbench_v1.py \ + --model_path "/path/to/your/model" \ + --longbench_path ./LongBench \ + --dataset narrativeqa qasper \ + --output_dir results/v1_rocket \ + --attention_backend VANILLA \ + --backend pytorch \ + --rocket_sparse +``` + +### LongBench v2 Evaluation + +#### Basic Usage (Standard Attention) +```bash +python eval_longbench_v2.py \ + --model_path "/path/to/your/model" \ + --longbench_path ./LongBench \ + --output_dir results/v2_vanilla +``` + +#### With Chain-of-Thought Reasoning +```bash +python eval_longbench_v2.py \ + --model_path "/path/to/your/model" \ + --longbench_path ./LongBench \ + --output_dir results/v2_cot \ + --cot +``` + +#### Filter by Difficulty/Length/Domain +```bash +# Easy questions only +python eval_longbench_v2.py \ + --model_path "/path/to/your/model" \ + --longbench_path ./LongBench \ + --output_dir results/v2_easy \ + --difficulty easy + +# Long context only +python eval_longbench_v2.py \ + --model_path "/path/to/your/model" \ + --longbench_path ./LongBench \ + --output_dir results/v2_long \ + --length long + +# Specific domain +python eval_longbench_v2.py \ + --model_path "/path/to/your/model" \ + --longbench_path ./LongBench \ + --output_dir results/v2_code \ + --domain "Code" +``` + +#### Limited Sample Evaluation (for testing) +```bash +python eval_longbench_v2.py \ + --model_path "/path/to/your/model" \ + --longbench_path ./LongBench \ + --output_dir results/v2_test \ + --num_samples 10 +``` + +## Output Structure + +### LongBench v1 Output + +```text +results/v1_experiment/ +├── config.json # Experiment configuration +├── overall_summary.json # Overall experiment summary +├── narrativeqa/ +│ ├── narrativeqa_results.jsonl # Detailed results +│ ├── narrativeqa_summary.json # Task summary +│ └── pred/ +│ └── narrativeqa.jsonl # Predictions in LongBench format +├── qasper/ +│ └── ... +└── ... 
+``` + +### LongBench v2 Output + +```text +results/v2_experiment/ +├── config.json # Experiment configuration +├── summary.json # Evaluation summary with metrics +├── longbench_v2_results.jsonl # Detailed results +└── predictions.jsonl # Predictions in LongBench v2 format +``` diff --git a/examples/longbench/eval_longbench_v1.py b/examples/longbench/eval_longbench_v1.py new file mode 100644 index 0000000000..2054266217 --- /dev/null +++ b/examples/longbench/eval_longbench_v1.py @@ -0,0 +1,797 @@ +#!/usr/bin/env python3 +""" +LongBench v1 evaluation script with TensorRT-LLM and sparse attention. + +Usage: + python longbench_rocket_eval.py --dataset narrativeqa --model_path /path/to/model --longbench_path ./LongBench --output_dir results/ + + # Run all LongBench tasks + python longbench_rocket_eval.py --model_path /path/to/model --longbench_path ./LongBench --output_dir results/ --token_budget 2048 --rocket_sparse +""" + +import argparse +import json +import os +import sys +import time +from datetime import datetime +from typing import Any, Dict, List, Tuple + +import numpy as np +from datasets import load_dataset +from transformers import AutoTokenizer + +# Add tensorrt_llm imports +from tensorrt_llm import LLM, SamplingParams +from tensorrt_llm.llmapi import KvCacheConfig, RocketSparseAttentionConfig +from tensorrt_llm.logger import logger + +LONGBENCH_DATASETS = ["narrativeqa", "qasper", "multifieldqa_en", "multifieldqa_zh", "hotpotqa", "2wikimqa", "musique", \ + "dureader", "gov_report", "qmsum", "multi_news", "vcsum", "trec", "triviaqa", "samsum", "lsht", \ + "passage_count", "passage_retrieval_en", "passage_retrieval_zh", "lcc", "repobench-p"] + +# Task categorization +TASK_DATASETS = { + 'single_doc_qa': + ["narrativeqa", "qasper", "multifieldqa_en", "multifieldqa_zh"], + 'multi_doc_qa': ["hotpotqa", "2wikimqa", "musique", "dureader"], + 'summarization': ["gov_report", "qmsum", "multi_news", "vcsum"], + 'few_shots': ["trec", "triviaqa", "samsum", "lsht"], + 'synthetic': + ["passage_count", "passage_retrieval_en", "passage_retrieval_zh"], + 'code': ["lcc", "repobench-p"] +} + +# Chat templates mapping +CHAT_TEMPLATES = { + "llama3.1-8b-instruct": "llama3", + "llama3-8b-instruct": "llama3", + "mistral-7b-instruct-v0.2": "mistral", + "longchat-7b-v1.5-32k": "vicuna" +} + + +def parse_arguments() -> argparse.Namespace: + """Parse command line arguments.""" + parser = argparse.ArgumentParser( + description="LongBench evaluation with TensorRT-LLM and RocketKV") + + # Model and data arguments + parser.add_argument('--model_path', + type=str, + required=True, + help='Path to model (HF model name or local path)') + parser.add_argument('--dataset', + type=str, + nargs='+', + choices=LONGBENCH_DATASETS, + help='LongBench datasets to evaluate on') + parser.add_argument('--run_all_tasks', + action='store_true', + help='Run evaluation on all LongBench tasks') + parser.add_argument('--longbench_path', + type=str, + default='./LongBench', + help='Path to LongBench directory') + + # Output arguments + parser.add_argument('--output_dir', + type=str, + required=True, + help='Directory to save results') + parser.add_argument('--exp_name', + type=str, + default=None, + help='Experiment name (auto-generated if not provided)') + + # Model configuration + parser.add_argument('--attention_backend', + type=str, + default='VANILLA', + choices=['VANILLA', 'TRTLLM', 'FLASHINFER'], + help='Attention backend to use') + parser.add_argument('--backend', + type=str, + default='pytorch', + choices=['pytorch', 
'tensorrt'], + help='LLM backend to use') + parser.add_argument('--chat_template', + type=str, + default='auto', + help='Chat template to use (auto-detect if "auto")') + + # Sequence and batch configuration + parser.add_argument('--max_seq_len', + type=int, + default=133120, + help='Maximum sequence length') + parser.add_argument('--max_batch_size', + type=int, + default=1, + help='Maximum batch size') + parser.add_argument('--max_new_tokens', + type=int, + default=256, + help='Maximum new tokens to generate') + parser.add_argument( + '--max_num_tokens', + type=int, + default=133120, + help='Maximum total tokens across all sequences in a batch') + parser.add_argument('--tensor_parallel_size', + type=int, + default=1, + help='Tensor parallel size') + + # RocketKV configuration + parser.add_argument('--rocket_sparse', + action='store_true', + help='Use rocket sparse attention') + parser.add_argument('--token_budget', + type=int, + default=2048, + help='Token budget for RocketKV (prompt_budget)') + parser.add_argument('--window_size', + type=int, + default=32, + help='Window size for RocketKV') + parser.add_argument('--kernel_size', + type=int, + default=63, + help='Kernel size for RocketKV') + parser.add_argument('--topr', + type=int, + default=90, + help='Top-r for RocketKV') + + # KV cache configuration + parser.add_argument('--kv_cache_dtype', + type=str, + default='auto', + help='KV cache data type') + parser.add_argument('--kv_cache_fraction', + type=float, + default=0.7, + help='Fraction of GPU memory for KV cache') + + # Evaluation parameters + parser.add_argument('--num_samples', + type=int, + default=None, + help='Number of samples to evaluate (None for all)') + parser.add_argument('--start_idx', + type=int, + default=0, + help='Start index for evaluation') + + # System arguments + parser.add_argument('--log_level', + type=str, + default='info', + choices=['debug', 'info', 'warning', 'error'], + help='Logging level') + parser.add_argument('--seed', type=int, default=42, help='Random seed') + + args = parser.parse_args() + + # Validation + if not args.run_all_tasks and not args.dataset: + parser.error("Must specify either --dataset or --run_all_tasks") + + return args + + +def setup_longbench_imports(longbench_path: str): + """Add LongBench to Python path and import required modules.""" + longbench_dir = os.path.join(longbench_path, "LongBench") # for v1 + if not os.path.exists(longbench_dir): + raise FileNotFoundError( + f"LongBench directory not found: {longbench_dir}") + + # Add to path + if longbench_dir not in sys.path: + sys.path.insert(0, longbench_dir) + + +def load_longbench_config( + longbench_path: str) -> Tuple[Dict[str, str], Dict[str, int]]: + """Load LongBench configuration files.""" + config_dir = os.path.join(longbench_path, "LongBench", "config") + + # Load dataset2prompt.json + prompt_file = os.path.join(config_dir, "dataset2prompt.json") + with open(prompt_file, 'r', encoding='utf-8') as f: + dataset2prompt = json.load(f) + + # Load dataset2maxlen.json + maxlen_file = os.path.join(config_dir, "dataset2maxlen.json") + with open(maxlen_file, 'r', encoding='utf-8') as f: + dataset2maxlen = json.load(f) + + return dataset2prompt, dataset2maxlen + + +# LongBench's build_chat function (simplified version) +def build_chat(tokenizer, prompt, chat_template): + """Build chat prompt following LongBench's approach.""" + if chat_template == "vicuna" or chat_template == "longchat": + try: + from fastchat.model import get_conversation_template + conv = 
get_conversation_template("vicuna") + conv.append_message(conv.roles[0], prompt) + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + except ImportError: + # Fallback if fastchat is not available + prompt = f"A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n\nUSER: {prompt}\nASSISTANT:" + elif chat_template == "llama2": + prompt = f"[INST]{prompt}[/INST]" + elif chat_template == "llama3": + prompt = f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + elif chat_template == "mistral": + prompt = f"[INST] {prompt} [/INST]" + # For other templates or "none", return prompt as-is + return prompt + + +def determine_chat_template(model_path: str, chat_template: str) -> str: + """Determine chat template based on model path.""" + if chat_template != 'auto': + return chat_template + + model_path_lower = model_path.lower() + + for model_key, template in CHAT_TEMPLATES.items(): + if model_key.replace('-', '').replace('.', + '') in model_path_lower.replace( + '-', '').replace('.', ''): + return template + + # Default fallback + if 'llama' in model_path_lower: + return 'llama3' + elif 'mistral' in model_path_lower: + return 'mistral' + else: + return 'none' # No special formatting + + +def post_process(pred: str, chat_template: str, dataset: str) -> str: + """Post-process prediction following LongBench's approach.""" + pred = pred.split("")[0] + elif "llama2" in chat_template.lower(): + pred = (pred.split("(Document")[0].split("\n\nQuestion")[0].split( + "\n\nAnswer")[0].split("[INST]")[0].split("[/INST]")[0].split( + "(Passage")[0].strip()) + if dataset == "samsum": + pred = pred.split("\n")[0].strip() + + return pred + + +def format_prompt_style(sample: Dict[str, Any], instruction: str, + chat_template: str, dataset: str, tokenizer) -> str: + """Format prompt following LongBench's approach.""" + # First format the instruction using the sample data + prompt = instruction.format(**sample) + + if dataset not in [ + "trec", "triviaqa", "samsum", "lsht", "lcc", "repobench-p" + ]: + prompt = build_chat(tokenizer, prompt, chat_template) + + return prompt + + +def initialize_llm(args: argparse.Namespace) -> Tuple[LLM, AutoTokenizer]: + """Initialize LLM and tokenizer.""" + logger.info(f"Initializing LLM with model: {args.model_path}") + + try: + # Configure KV cache + kv_cache_config = KvCacheConfig( + enable_block_reuse=False, # RocketKV doesn't support KV cache reuse + ) + + if args.rocket_sparse: + # Configure RocketKV sparse attention + sparse_attention_config = RocketSparseAttentionConfig( + window_size=args.window_size, + kernel_size=args.kernel_size, + prompt_budget=args.token_budget, + topr=args.topr, + ) + logger.info(f"Using RocketKV sparse attention") + else: + sparse_attention_config = None + logger.info("Using standard attention") + + # Initialize LLM + llm = LLM( + model=args.model_path, + backend=args.backend, + kv_cache_config=kv_cache_config, + max_batch_size=args.max_batch_size, + attn_backend=args.attention_backend, + sparse_attention_config=sparse_attention_config, + tensor_parallel_size=args.tensor_parallel_size, + max_seq_len=args.max_seq_len, + max_num_tokens=args.max_num_tokens, + cuda_graph_config=None, + torch_compile_config=None, + ) + + # Initialize tokenizer + tokenizer = AutoTokenizer.from_pretrained(args.model_path) + + logger.info("LLM and tokenizer initialized 
successfully") + + return llm, tokenizer + + except Exception as e: + logger.error(f"Failed to initialize LLM: {e}") + raise e + + +def evaluate_single_dataset( + dataset: str, llm: LLM, tokenizer: AutoTokenizer, + args: argparse.Namespace) -> Tuple[List[Dict], float]: + """Evaluate a single dataset.""" + setup_longbench_imports(args.longbench_path) + + dataset2prompt, dataset2maxlen = load_longbench_config(args.longbench_path) + + # Load dataset + logger.info(f"Loading dataset: {dataset}") + data = [ + data_sample for data_sample in load_dataset( + 'THUDM/LongBench', dataset, split='test', trust_remote_code=True) + ] + + # Apply data filtering + if args.num_samples: + end_idx = min(args.start_idx + args.num_samples, len(data)) + filtered_data = data[args.start_idx:end_idx] + else: + filtered_data = data[args.start_idx:] + + logger.info(f"Dataset {dataset}: {len(filtered_data)} samples to evaluate") + + # Determine chat template + chat_template = determine_chat_template(args.model_path, args.chat_template) + logger.info(f"Using chat template: {chat_template}") + + # Create sampling parameters + max_new_tokens = dataset2maxlen[dataset] + prompt_format = dataset2prompt[dataset] + + # Set up extra end token ids + extra_end_token_ids = [] + if chat_template == "llama3": + eot_id = tokenizer.encode("<|eot_id|>", add_special_tokens=False)[0] + extra_end_token_ids.append(eot_id) + logger.info(f"Added llama3 end token: {eot_id}") + + if chat_template == "qwen": + im_end_id = tokenizer.encode("<|im_end|>", add_special_tokens=False)[0] + extra_end_token_ids.append(im_end_id) + logger.info(f"Added qwen end token: {im_end_id}") + + if dataset == "samsum": + newline_id = tokenizer.encode("\n", add_special_tokens=False)[-1] + extra_end_token_ids.append(newline_id) + logger.info(f"Added samsum newline token: {newline_id}") + + # Prepare prompts + prompts = [] + for sample in filtered_data: + formatted_prompt = format_prompt_style(sample, prompt_format, + chat_template, dataset, + tokenizer) + prompts.append(formatted_prompt) + + if len(prompts) == 0: + logger.warning(f"No prompts to evaluate for dataset {dataset}") + return [], 0.0 + + # Run inference + logger.info(f"Starting inference for {len(prompts)} samples...") + start_time = time.time() + + sampling_params = SamplingParams( + max_tokens=max_new_tokens, + temperature=0.8, + top_p=0.95, + stop_token_ids=extra_end_token_ids if extra_end_token_ids else None, + ) + + outputs = llm.generate(prompts, sampling_params) + + inference_time = time.time() - start_time + logger.info( + f"Inference completed in {inference_time:.2f} seconds, average time per sample: {inference_time/len(prompts):.3f} seconds" + ) + + # Prepare results + results = [] + for i, (sample, output) in enumerate(zip(filtered_data, outputs)): + prediction = output.outputs[0].text.strip() + processed_prediction = post_process(prediction, chat_template, dataset) + + result = { + 'sample_id': args.start_idx + i, + 'input': sample.get('input', ''), + 'context': sample.get('context', ''), + 'answers': sample.get('answers', []), + 'all_classes': sample.get('all_classes', []), + 'prediction': processed_prediction, + 'raw_prediction': prediction, + 'prompt_length': len(output.prompt_token_ids), + 'output_length': len(output.outputs[0].token_ids), + 'inference_time': getattr(output, 'inference_time', None), + 'length': sample.get('length', 0) + } + results.append(result) + + return results, inference_time + + +def calculate_metrics( + dataset: str, predictions: List[str], answers_list: 
List[List[str]], + all_classes_list: List[List[str]], + longbench_path: str) -> Tuple[Dict[str, float], List[float]]: + """Calculate evaluation metrics for a dataset following LongBench's implementation.""" + + # Setup LongBench imports + setup_longbench_imports(longbench_path) + + # Import LongBench metrics + from metrics import (classification_score, code_sim_score, count_score, + qa_f1_score, qa_f1_zh_score, retrieval_score, + retrieval_zh_score, rouge_score, rouge_zh_score) + + # Mapping of datasets to their metric functions (from LongBench) + dataset2metric = { + "narrativeqa": qa_f1_score, + "qasper": qa_f1_score, + "multifieldqa_en": qa_f1_score, + "multifieldqa_zh": qa_f1_zh_score, + "hotpotqa": qa_f1_score, + "2wikimqa": qa_f1_score, + "musique": qa_f1_score, + "dureader": rouge_zh_score, + "gov_report": rouge_score, + "qmsum": rouge_score, + "multi_news": rouge_score, + "vcsum": rouge_zh_score, + "trec": classification_score, + "triviaqa": qa_f1_score, + "samsum": rouge_score, + "lsht": classification_score, + "passage_retrieval_en": retrieval_score, + "passage_count": count_score, + "passage_retrieval_zh": retrieval_zh_score, + "lcc": code_sim_score, + "repobench-p": code_sim_score, + } + + if dataset not in dataset2metric: + # Fallback to simple exact match with cleaning + total_score = 0 + scores = [] + for pred, answers in zip(predictions, answers_list): + cleaned_pred = pred.lstrip('\n').split('\n')[0].strip() + score = max([ + 1.0 if cleaned_pred.lower() == ans.strip().lower() else 0.0 + for ans in answers + ]) + scores.append(score) + total_score += score + return { + "exact_match": round(100 * total_score / len(predictions), 2) + }, scores + + metric_func = dataset2metric[dataset] + total_score = 0.0 + scores = [] + + # Follow LongBench's scorer function exactly + for pred, ground_truths, all_classes in zip(predictions, answers_list, + all_classes_list): + score = 0.0 + + # Apply the same prediction cleaning as LongBench + if dataset in ["trec", "triviaqa", "samsum", "lsht"]: + pred = pred.lstrip('\n').split('\n')[0] + + # For code datasets, apply additional cleaning + if dataset in ["lcc", "repobench-p"]: + # This cleaning is done inside code_sim_score, but let's also apply it here for consistency + all_lines = pred.lstrip('\n').split('\n') + for line in all_lines: + if ('`' not in line) and ('#' not in line) and ('//' + not in line): + pred = line + break + + # Calculate max score across all reference answers (exactly as in LongBench) + for ground_truth in ground_truths: + score = max( + score, metric_func(pred, ground_truth, all_classes=all_classes)) + + scores.append(score) + total_score += score + + final_score = round(100 * total_score / len(predictions), 2) + return {metric_func.__name__: final_score}, scores + + +def calculate_task_summary(all_results: Dict[str, Dict]) -> Dict[str, Any]: + """Calculate task-level summary statistics following long_bench_tasks_summary.py approach.""" + logger.info("Calculating task-level summary statistics...") + + summary = {} + ind_dataset_result = {} + task_ave_result = {} + + NA_flag = False + + # Get individual dataset results + for dataset in LONGBENCH_DATASETS: + if dataset in all_results and 'metrics' in all_results[dataset]: + metrics = all_results[dataset]['metrics'] + # Get the first (and usually only) metric value + if metrics: + metric_key = list(metrics.keys())[0] + val = metrics[metric_key] + ind_dataset_result[dataset] = val + else: + ind_dataset_result[dataset] = 'N/A' + NA_flag = True + else: + 
ind_dataset_result[dataset] = 'N/A' + NA_flag = True + + summary['individual_dataset_result'] = ind_dataset_result + + # Calculate task-average results + for task, datasets in TASK_DATASETS.items(): + task_NA_flag = False + task_ave_result[task] = 0 + valid_count = 0 + + for dataset in datasets: + if dataset in ind_dataset_result and ind_dataset_result[ + dataset] != 'N/A': + task_ave_result[task] += ind_dataset_result[dataset] + valid_count += 1 + else: + task_NA_flag = True + + if task_NA_flag or valid_count == 0: + task_ave_result[task] = 'N/A' + else: + task_ave_result[task] = np.round(task_ave_result[task] / + valid_count, + decimals=2) + + summary['task_average_result'] = task_ave_result + + # Calculate overall LongBench average result (excluding passage_count as in original) + if NA_flag: + summary['LB_average_result'] = 'N/A' + else: + average_result = 0 + valid_count = 0 + for dataset in LONGBENCH_DATASETS: + if dataset != 'passage_count' and dataset in ind_dataset_result: + if ind_dataset_result[dataset] != 'N/A': + average_result += ind_dataset_result[dataset] + valid_count += 1 + + if valid_count > 0: + summary['LB_average_result'] = np.round(average_result / + valid_count, + decimals=2) + else: + summary['LB_average_result'] = 'N/A' + + # Log summary statistics + logger.info("Task Summary Results:") + logger.info(f"Overall LongBench Average: {summary['LB_average_result']}") + for task, score in task_ave_result.items(): + logger.info(f"{task}: {score}") + + return summary + + +def save_results(results: List[Dict], dataset: str, args: argparse.Namespace, + inference_time: float, output_dir: str): + """Save evaluation results in format compatible with LongBench.""" + task_output_dir = os.path.join(output_dir, dataset) + os.makedirs(task_output_dir, exist_ok=True) + + # Extract predictions, answers, and all_classes for evaluation + predictions = [r['prediction'] for r in results] + answers_list = [r['answers'] for r in results] + all_classes_list = [r.get('all_classes', []) for r in results] + + # Calculate metrics + processed_results, scores = calculate_metrics(dataset, predictions, + answers_list, + all_classes_list, + args.longbench_path) + logger.info(f"Evaluation metrics: {processed_results}") + + # Save detailed results for manual inspection + results_file = os.path.join(task_output_dir, f"{dataset}_results.jsonl") + with open(results_file, 'w', encoding='utf-8') as f: + for result in results: + json.dump(result, f, ensure_ascii=False) + f.write('\n') + + # Save prediction results in LongBench format for evaluation + pred_dir = os.path.join(task_output_dir, "pred") + os.makedirs(pred_dir, exist_ok=True) + pred_file = os.path.join(pred_dir, f"{dataset}.jsonl") + + with open(pred_file, 'w', encoding='utf-8') as f: + for idx, result in enumerate(results): + pred_data = { + "pred": result['prediction'], + "answers": result['answers'], + "all_classes": result.get('all_classes', []), + "length": result.get('length', 0), + "score": scores[idx] + } + json.dump(pred_data, f, ensure_ascii=False) + f.write('\n') + + # Create summary following LongBench format + config = { + 'pipeline_params': { + 'model_name': args.model_path, + 'method': args.attention_backend, + 'token_budget': args.token_budget, + 'max_seq_len': args.max_seq_len, + 'max_new_tokens': args.max_new_tokens, + 'window_size': args.window_size, + 'kernel_size': args.kernel_size, + 'num_processes': 1, # Single process + 'devices': "0" # Single device + }, + 'eval_params': { + 'dataset': dataset, + 'num_samples': 
len(results) + }, + 'eval_results': { + 'processed_results': processed_results + }, + 'management': { + 'output_folder_dir': task_output_dir, + 'exp_desc': + f'{dataset}_{os.path.basename(args.model_path)}_{args.attention_backend}_{args.token_budget}', + 'total_inference_time': inference_time, + 'avg_inference_time': inference_time / len(results), + 'evaluation_timestamp': datetime.now().isoformat() + } + } + + # Save summary + summary_file = os.path.join(task_output_dir, f"{dataset}_summary.json") + with open(summary_file, 'w', encoding='utf-8') as f: + json.dump(config, f, indent=2, ensure_ascii=False) + + logger.info(f"The results of {dataset} are saved to {task_output_dir}") + + return processed_results + + +def main(): + """Main evaluation function.""" + args = parse_arguments() + logger.set_level(args.log_level) + + # Setup experiment name + if not args.exp_name: + timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') + model_name = os.path.basename(args.model_path).replace('/', '_') + args.exp_name = f"longbench_{model_name}_{timestamp}" + + output_dir = os.path.join(args.output_dir, args.exp_name) + + logger.info( + "=========== LongBench Evaluation with TensorRT-LLM ===========") + + os.makedirs(output_dir, exist_ok=True) + # Save configuration + config_file = os.path.join(output_dir, "config.json") + with open(config_file, 'w') as f: + json.dump(vars(args), f, indent=2) + logger.info(f"Configuration saved to {config_file}") + + # Determine datasets to evaluate + if args.run_all_tasks: + datasets = LONGBENCH_DATASETS + logger.info(f"Running evaluation on full LongBench datasets") + else: + datasets = args.dataset + logger.info(f"Running evaluation on datasets: {args.dataset}") + + # Initialize LLM and tokenizer + llm, tokenizer = initialize_llm(args) + + # Process datasets sequentially + all_results = {} + for dataset_idx, dataset in enumerate(datasets): + logger.info(f"{'='*30}") + logger.info( + f"Processing dataset {dataset_idx+1}/{len(datasets)}: {dataset}...") + + # Evaluate the dataset + results, inference_time = evaluate_single_dataset( + dataset, llm, tokenizer, args) + + # Save results and get metrics + processed_results = save_results(results, dataset, args, inference_time, + output_dir) + + all_results[dataset] = { + 'num_samples': len(results), + 'inference_time': inference_time, + 'output_dir': output_dir, + 'metrics': processed_results + } + logger.info(f"Dataset {dataset} completed successfully") + + total_time = sum(all_results[d]['inference_time'] for d in all_results) + + # Calculate task-level summary + task_summary = calculate_task_summary(all_results) + + # Save overall summary with task statistics + overall_summary = { + 'experiment_name': + args.exp_name, + 'total_evaluation_time': + total_time, + 'evaluated_datasets': + list(all_results.keys()), + 'successful_datasets': + [d for d, r in all_results.items() if 'error' not in r], + 'failed_datasets': [d for d, r in all_results.items() if 'error' in r], + 'results_by_dataset': + all_results, + 'task_summary': + task_summary, # Add task-level summary + 'configuration': + vars(args) + } + + overall_summary_file = os.path.join(output_dir, "overall_summary.json") + with open(overall_summary_file, 'w') as f: + json.dump(overall_summary, f, indent=2) + + logger.info(f"\n{'-'*80}") + logger.info( + f"Evaluation completed. 
Overall summary saved to: {overall_summary_file}" + ) + logger.info(f"Total time: {total_time:.2f} seconds") + + # Print final summary + if task_summary['LB_average_result'] != 'N/A': + logger.info(f"FINAL RESULTS:") + logger.info( + f"LongBench Overall Average: {task_summary['LB_average_result']}") + logger.info(f"Task-level results:") + for task, score in task_summary['task_average_result'].items(): + logger.info(f" {task}: {score}") + + if overall_summary['failed_datasets']: + logger.warning(f"Failed datasets: {overall_summary['failed_datasets']}") + + +if __name__ == '__main__': + main() diff --git a/examples/longbench/eval_longbench_v2.py b/examples/longbench/eval_longbench_v2.py new file mode 100644 index 0000000000..e9bd3d3800 --- /dev/null +++ b/examples/longbench/eval_longbench_v2.py @@ -0,0 +1,771 @@ +#!/usr/bin/env python3 +""" +LongBench v2 evaluation script with TensorRT-LLM and sparse attention. + +Usage: + python eval_longbench_v2.py --model_path /path/to/model --longbench_path ./LongBench --output_dir results/ + + # Run all LongBench v2 samples + python eval_longbench_v2.py --model_path /path/to/model --longbench_path ./LongBench --output_dir results/ + + # Enable CoT reasoning + python eval_longbench_v2.py --model_path /path/to/model --longbench_path ./LongBench --output_dir results/ --cot + + # Run with different difficulty levels + python eval_longbench_v2.py --model_path /path/to/model --longbench_path ./LongBench --output_dir results/ --difficulty easy +""" + +import argparse +import json +import os +import re +import time +from datetime import datetime +from typing import Any, Dict, List, Optional, Tuple + +from datasets import load_dataset +from transformers import AutoTokenizer + +# Add tensorrt_llm imports +from tensorrt_llm import LLM, SamplingParams +from tensorrt_llm.llmapi import KvCacheConfig, RocketSparseAttentionConfig +from tensorrt_llm.logger import logger + +# Chat templates mapping +CHAT_TEMPLATES = { + "llama3.1-8b-instruct": "llama3", + "llama3-8b-instruct": "llama3", + "mistral-7b-instruct-v0.2": "mistral", + "longchat-7b-v1.5-32k": "vicuna" +} + + +def parse_arguments() -> argparse.Namespace: + """Parse command line arguments.""" + parser = argparse.ArgumentParser( + description="LongBench v2 evaluation with TensorRT-LLM and RocketKV") + + # Model and data arguments + parser.add_argument('--model_path', + type=str, + required=True, + help='Path to model (HF model name or local path)') + parser.add_argument('--longbench_path', + type=str, + default='./LongBench', + help='Path to LongBench directory') + + # Output arguments + parser.add_argument('--output_dir', + type=str, + required=True, + help='Directory to save results') + parser.add_argument('--exp_name', + type=str, + default=None, + help='Experiment name (auto-generated if not provided)') + + # Model configuration + parser.add_argument('--attention_backend', + type=str, + default='VANILLA', + choices=['VANILLA', 'TRTLLM', 'FLASHINFER'], + help='Attention backend to use') + parser.add_argument('--backend', + type=str, + default='pytorch', + choices=['pytorch', 'tensorrt'], + help='LLM backend to use') + parser.add_argument('--chat_template', + type=str, + default='auto', + help='Chat template to use (auto-detect if "auto")') + + # Sequence and batch configuration + parser.add_argument('--max_seq_len', + type=int, + default=133120, + help='Maximum sequence length') + parser.add_argument('--max_batch_size', + type=int, + default=1, + help='Maximum batch size') + 
parser.add_argument('--max_new_tokens', + type=int, + default=256, + help='Maximum new tokens to generate') + parser.add_argument( + '--max_num_tokens', + type=int, + default=133120, + help='Maximum total tokens across all sequences in a batch') + parser.add_argument('--tensor_parallel_size', + type=int, + default=1, + help='Tensor parallel size') + + # RocketKV configuration + parser.add_argument('--rocket_sparse', + action='store_true', + help='Use rocket sparse attention') + parser.add_argument('--token_budget', + type=int, + default=2048, + help='Token budget for RocketKV (prompt_budget)') + parser.add_argument('--window_size', + type=int, + default=32, + help='Window size for RocketKV') + parser.add_argument('--kernel_size', + type=int, + default=63, + help='Kernel size for RocketKV') + parser.add_argument('--topr', + type=int, + default=90, + help='Top-r for RocketKV') + + # KV cache configuration + parser.add_argument('--kv_cache_dtype', + type=str, + default='auto', + help='KV cache data type') + parser.add_argument('--kv_cache_fraction', + type=float, + default=0.7, + help='Fraction of GPU memory for KV cache') + + # LongBench v2 specific arguments + parser.add_argument('--cot', + action='store_true', + help='Enable Chain-of-Thought reasoning') + parser.add_argument('--no_context', + action='store_true', + help='Test without long context (pure memorization)') + parser.add_argument('--rag', + type=int, + default=0, + help='Use top-N retrieved contexts (0 to disable)') + + # Evaluation parameters + parser.add_argument('--num_samples', + type=int, + default=None, + help='Number of samples to evaluate (None for all)') + parser.add_argument('--start_idx', + type=int, + default=0, + help='Start index for evaluation') + parser.add_argument('--difficulty', + type=str, + choices=['easy', 'hard'], + default=None, + help='Filter by difficulty level') + parser.add_argument('--length', + type=str, + choices=['short', 'medium', 'long'], + default=None, + help='Filter by length category') + parser.add_argument('--domain', + type=str, + default=None, + help='Filter by domain') + parser.add_argument('--max_len', + type=int, + default=120000, + help='Maximum prompt length in tokens for truncation') + + # System arguments + parser.add_argument('--log_level', + type=str, + default='info', + choices=['debug', 'info', 'warning', 'error'], + help='Logging level') + parser.add_argument('--seed', type=int, default=42, help='Random seed') + + return parser.parse_args() + + +def load_longbench_v2_config(longbench_path: str) -> Dict[str, Any]: + """Load LongBench v2 configuration files.""" + config_dir = os.path.join(longbench_path, "config") + + # Load model2maxlen.json for v2 + maxlen_file = os.path.join(config_dir, "model2maxlen.json") + with open(maxlen_file, 'r', encoding='utf-8') as f: + model2maxlen = json.load(f) + + # Load prompt templates + prompts_dir = os.path.join(longbench_path, "prompts") + + templates = {} + template_files = { + '0shot': '0shot.txt', + '0shot_cot': '0shot_cot.txt', + '0shot_cot_ans': '0shot_cot_ans.txt', + '0shot_no_context': '0shot_no_context.txt', + '0shot_rag': '0shot_rag.txt' + } + + for template_name, filename in template_files.items(): + template_path = os.path.join(prompts_dir, filename) + if os.path.exists(template_path): + with open(template_path, 'r', encoding='utf-8') as f: + templates[template_name] = f.read() + + return {'model2maxlen': model2maxlen, 'templates': templates} + + +def build_chat(tokenizer, prompt, chat_template): + """Build chat prompt following 
LongBench's approach.""" + if chat_template == "vicuna" or chat_template == "longchat": + try: + from fastchat.model import get_conversation_template + conv = get_conversation_template("vicuna") + conv.append_message(conv.roles[0], prompt) + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + except ImportError: + # Fallback if fastchat is not available + prompt = f"A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n\nUSER: {prompt}\nASSISTANT:" + elif chat_template == "llama2": + prompt = f"[INST]{prompt}[/INST]" + elif chat_template == "llama3": + prompt = f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + elif chat_template == "mistral": + prompt = f"[INST] {prompt} [/INST]" + # For other templates or "none", return prompt as-is + return prompt + + +def determine_chat_template(model_path: str, chat_template: str) -> str: + """Determine chat template based on model path.""" + if chat_template != 'auto': + return chat_template + + model_path_lower = model_path.lower() + + for model_key, template in CHAT_TEMPLATES.items(): + if model_key.replace('-', '').replace('.', + '') in model_path_lower.replace( + '-', '').replace('.', ''): + return template + + # Default fallback + if 'llama' in model_path_lower: + return 'llama3' + elif 'mistral' in model_path_lower: + return 'mistral' + else: + return 'none' # No special formatting + + +def extract_answer(response: str) -> Optional[str]: + """Extract answer from response following LongBench v2's approach.""" + response = response.replace('*', '') + + # Try to extract answer in format "The correct answer is (X)" + match = re.search(r'The correct answer is \(([A-D])\)', response) + if match: + return match.group(1) + + # Try to extract answer in format "The correct answer is X" + match = re.search(r'The correct answer is ([A-D])', response) + if match: + return match.group(1) + + # Try to extract any single letter A-D + match = re.search(r'\b([A-D])\b', response) + if match: + return match.group(1) + + return None + + +def post_process(pred: str, chat_template: str) -> str: + """Post-process prediction following LongBench v2's approach.""" + pred = pred.split("")[0] + elif "llama2" in chat_template.lower(): + pred = (pred.split("(Document")[0].split("\n\nQuestion")[0].split( + "\n\nAnswer")[0].split("[INST]")[0].split("[/INST]")[0].split( + "(Passage")[0].strip()) + + return pred + + +def truncate_prompt(prompt: str, tokenizer: AutoTokenizer, max_len: int) -> str: + """Truncate prompt following LongBench v2's approach.""" + # Encode the prompt using the tokenizer + input_ids = tokenizer.encode(prompt, add_special_tokens=False) + + # If prompt exceeds max_len, truncate by taking first half and last half + if len(input_ids) > max_len: + half = max_len // 2 + truncated_ids = input_ids[:half] + input_ids[-half:] + # Decode back to text + prompt = tokenizer.decode(truncated_ids, skip_special_tokens=True) + + return prompt + + +def format_prompt(sample: Dict[str, Any], template: str, + args: argparse.Namespace) -> str: + """Format prompt for LongBench v2.""" + context = sample['context'] + + # Handle RAG mode + if args.rag > 0 and 'retrieved_context' in sample: + retrieved = sample["retrieved_context"][:args.rag] + retrieved = sorted(retrieved, key=lambda x: x.get('c_idx', 0)) + context = '\n\n'.join([ + f"Retrieved chunk {idx+1}: {x['content']}" + 
for idx, x in enumerate(retrieved) + ]) + + # Handle no context mode + if args.no_context: + context = "" + + # Format the prompt using the template + prompt = template.replace('$DOC$', context.strip()) + prompt = prompt.replace('$Q$', sample['question'].strip()) + prompt = prompt.replace('$C_A$', sample['choice_A'].strip()) + prompt = prompt.replace('$C_B$', sample['choice_B'].strip()) + prompt = prompt.replace('$C_C$', sample['choice_C'].strip()) + prompt = prompt.replace('$C_D$', sample['choice_D'].strip()) + + return prompt + + +def initialize_llm(args: argparse.Namespace) -> Tuple[LLM, AutoTokenizer]: + """Initialize LLM and tokenizer.""" + logger.info(f"Initializing LLM with model: {args.model_path}") + + try: + # Configure KV cache + kv_cache_config = KvCacheConfig( + enable_block_reuse=False, # RocketKV doesn't support KV cache reuse + ) + + if args.rocket_sparse: + # Configure RocketKV sparse attention + sparse_attention_config = RocketSparseAttentionConfig( + window_size=args.window_size, + kernel_size=args.kernel_size, + prompt_budget=args.token_budget, + topr=args.topr, + ) + logger.info("Using RocketKV sparse attention") + else: + sparse_attention_config = None + logger.info("Using standard attention") + + # Initialize LLM + llm = LLM( + model=args.model_path, + backend=args.backend, + kv_cache_config=kv_cache_config, + attn_backend=args.attention_backend, + sparse_attention_config=sparse_attention_config, + tensor_parallel_size=args.tensor_parallel_size, + max_seq_len=args.max_seq_len, + max_num_tokens=args.max_num_tokens, + cuda_graph_config=None, + torch_compile_config=None, + ) + + # Initialize tokenizer + tokenizer = AutoTokenizer.from_pretrained(args.model_path) + + logger.info("LLM and tokenizer initialized successfully") + + return llm, tokenizer + + except Exception as e: + logger.error(f"Failed to initialize LLM: {e}") + raise e + + +def evaluate_longbench_v2(llm: LLM, tokenizer: AutoTokenizer, + args: argparse.Namespace) -> Tuple[List[Dict], float]: + """Evaluate on LongBench v2 dataset.""" + + # Load LongBench v2 configuration + config = load_longbench_v2_config(args.longbench_path) + + # Determine max_len for the model (the config value takes precedence over --max_len) + model_name = os.path.basename(args.model_path) + if model_name in config[ + 'model2maxlen']: # Use default from config if available + max_len = config['model2maxlen'][model_name] + logger.info(f"Using model-specific max_len: {max_len} for {model_name}") + else: + max_len = args.max_len + logger.info(f"Using max_len: {max_len}") + + # Update args with the determined max_len + args.max_len = max_len + + # Load dataset + logger.info(f"Loading LongBench v2 dataset...") + dataset = load_dataset('THUDM/LongBench-v2', + split='train', + trust_remote_code=True) + data = [item for item in dataset] + + # Apply filters + filtered_data = data + + if args.difficulty: + filtered_data = [ + item for item in filtered_data + if item['difficulty'] == args.difficulty + ] + logger.info( + f"Filtered by difficulty '{args.difficulty}': {len(filtered_data)} samples" + ) + + if args.length: + filtered_data = [ + item for item in filtered_data if item['length'] == args.length + ] + logger.info( + f"Filtered by length '{args.length}': {len(filtered_data)} samples") + + if args.domain: + filtered_data = [ + item for item in filtered_data if item['domain'] == args.domain + ] + logger.info( + f"Filtered by domain '{args.domain}': {len(filtered_data)} samples") + + # Apply start_idx and num_samples + if args.num_samples: + end_idx =
min(args.start_idx + args.num_samples, len(filtered_data)) + filtered_data = filtered_data[args.start_idx:end_idx] + else: + filtered_data = filtered_data[args.start_idx:] + + logger.info(f"Final dataset size: {len(filtered_data)} samples") + + # Determine chat template + chat_template = determine_chat_template(args.model_path, args.chat_template) + logger.info(f"Using chat template: {chat_template}") + + logger.info(f"Prepare and truncate prompts...") + # Select appropriate template + if args.no_context: + template_key = '0shot_no_context' + elif args.rag > 0: + template_key = '0shot_rag' + elif args.cot: + template_key = '0shot_cot' + else: + template_key = '0shot' + + template = config['templates'][template_key] + logger.info(f"Using template: {template_key}") + + # Set up extra end token ids + extra_end_token_ids = [] + if chat_template == "llama3": + eot_id = tokenizer.encode("<|eot_id|>", add_special_tokens=False)[0] + extra_end_token_ids.append(eot_id) + logger.info(f"Added llama3 end token: {eot_id}") + + if chat_template == "qwen": + im_end_id = tokenizer.encode("<|im_end|>", add_special_tokens=False)[0] + extra_end_token_ids.append(im_end_id) + logger.info(f"Added qwen end token: {im_end_id}") + + # Prepare prompts + prompts = [] + for sample in filtered_data: + formatted_prompt = format_prompt(sample, template, args) + + # Apply chat formatting if needed + if chat_template != 'none': + formatted_prompt = build_chat(tokenizer, formatted_prompt, + chat_template) + + # Apply truncation if prompt is too long + formatted_prompt = truncate_prompt(formatted_prompt, tokenizer, + args.max_len) + + prompts.append(formatted_prompt) + + if len(prompts) == 0: + logger.warning(f"No prompts to evaluate") + return [], 0.0 + + # Run inference + logger.info(f"Starting inference for {len(prompts)} samples...") + start_time = time.time() + + # Set sampling parameters + max_new_tokens = 1024 if args.cot else 256 + sampling_params = SamplingParams( + max_tokens=max_new_tokens, + temperature=0.1, + top_p=0.95, + stop_token_ids=extra_end_token_ids if extra_end_token_ids else None, + ) + + outputs = llm.generate(prompts, sampling_params) + + inference_time = time.time() - start_time + logger.info( + f"Inference completed in {inference_time:.2f} seconds, average time per sample: {inference_time/len(prompts):.3f} seconds" + ) + + # Process outputs + results = [] + for i, (sample, output) in enumerate(zip(filtered_data, outputs)): + prediction = output.outputs[0].text.strip() + processed_prediction = post_process(prediction, chat_template) + + # Handle CoT mode + if args.cot: + # For CoT, we need to do a second inference to extract the final answer + cot_response = processed_prediction + + # Format the CoT answer extraction prompt + cot_ans_template = config['templates']['0shot_cot_ans'] + cot_ans_prompt = format_prompt(sample, cot_ans_template, args) + cot_ans_prompt = cot_ans_prompt.replace('$COT$', cot_response) + + if chat_template != 'none': + cot_ans_prompt = build_chat(tokenizer, cot_ans_prompt, + chat_template) + + # Apply truncation to CoT answer extraction prompt + cot_ans_prompt = truncate_prompt(cot_ans_prompt, tokenizer, + args.max_len) + + # Generate final answer + ans_sampling_params = SamplingParams( + max_tokens=128, + temperature=0.1, + top_p=0.95, + stop_token_ids=extra_end_token_ids + if extra_end_token_ids else None, + ) + + ans_output = llm.generate([cot_ans_prompt], ans_sampling_params)[0] + final_prediction = post_process(ans_output.outputs[0].text.strip(), + chat_template) + + 
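+            # The second pass above feeds the model's reasoning back through the
+            # $COT$ placeholder so the final letter choice can be parsed from a
+            # short completion rather than from the full chain of thought.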
extracted_answer = extract_answer(final_prediction) + else: + extracted_answer = extract_answer(processed_prediction) + + # Calculate accuracy + is_correct = extracted_answer == sample[ + 'answer'] if extracted_answer else False + + result = { + '_id': sample['_id'], + 'domain': sample['domain'], + 'sub_domain': sample['sub_domain'], + 'difficulty': sample['difficulty'], + 'length': sample['length'], + 'question': sample['question'], + 'choice_A': sample['choice_A'], + 'choice_B': sample['choice_B'], + 'choice_C': sample['choice_C'], + 'choice_D': sample['choice_D'], + 'answer': sample['answer'], + 'prediction': processed_prediction, + 'extracted_answer': extracted_answer, + 'is_correct': is_correct, + 'context_length': len(sample['context']), + 'prompt_length': len(output.prompt_token_ids), + 'output_length': len(output.outputs[0].token_ids), + } + + if args.cot: + result['cot_response'] = cot_response + result['final_prediction'] = final_prediction + + results.append(result) + + return results, inference_time + + +def calculate_metrics(results: List[Dict]) -> Dict[str, Any]: + """Calculate evaluation metrics for LongBench v2.""" + if not results: + return {} + + total_samples = len(results) + correct_samples = sum(1 for r in results if r['is_correct']) + overall_accuracy = correct_samples / total_samples + + metrics = { + 'overall_accuracy': round(overall_accuracy * 100, 2), + 'total_samples': total_samples, + 'correct_samples': correct_samples + } + + # Calculate metrics by difficulty + difficulties = ['easy', 'hard'] + for difficulty in difficulties: + diff_results = [r for r in results if r['difficulty'] == difficulty] + if diff_results: + diff_correct = sum(1 for r in diff_results if r['is_correct']) + metrics[f'{difficulty}_accuracy'] = round( + (diff_correct / len(diff_results)) * 100, 2) + + # Calculate metrics by length + lengths = ['short', 'medium', 'long'] + for length in lengths: + len_results = [r for r in results if r['length'] == length] + if len_results: + len_correct = sum(1 for r in len_results if r['is_correct']) + metrics[f'{length}_accuracy'] = round( + (len_correct / len(len_results)) * 100, 2) + + # Calculate metrics by domain + domains = list(set(r['domain'] for r in results)) + for domain in domains: + domain_results = [r for r in results if r['domain'] == domain] + if domain_results: + domain_correct = sum(1 for r in domain_results if r['is_correct']) + metrics[f'{domain}_accuracy'] = round( + (domain_correct / len(domain_results)) * 100, 2) + + return metrics + + +def save_results(results: List[Dict], args: argparse.Namespace, + inference_time: float, output_dir: str): + """Save evaluation results in format compatible with LongBench v2.""" + os.makedirs(output_dir, exist_ok=True) + + # Calculate metrics + metrics = calculate_metrics(results) + logger.info(f"Evaluation metrics: {metrics}") + + # Save detailed results + results_file = os.path.join(output_dir, "longbench_v2_results.jsonl") + with open(results_file, 'w', encoding='utf-8') as f: + for result in results: + json.dump(result, f, ensure_ascii=False) + f.write('\n') + + # Save prediction results in LongBench v2 format + pred_file = os.path.join(output_dir, "predictions.jsonl") + with open(pred_file, 'w', encoding='utf-8') as f: + for result in results: + pred_data = { + "_id": result['_id'], + "prediction": result['extracted_answer'], + "response": result['prediction'], + "judge": result['is_correct'] + } + if args.cot: + pred_data['cot_response'] = result.get('cot_response', '') + 
pred_data['final_prediction'] = result.get( + 'final_prediction', '') + + json.dump(pred_data, f, ensure_ascii=False) + f.write('\n') + + # Create summary + summary = { + 'experiment_config': { + 'model_path': args.model_path, + 'attention_backend': args.attention_backend, + 'rocket_sparse': args.rocket_sparse, + 'token_budget': args.token_budget, + 'cot': args.cot, + 'no_context': args.no_context, + 'rag': args.rag, + 'difficulty_filter': args.difficulty, + 'length_filter': args.length, + 'domain_filter': args.domain, + 'max_seq_len': args.max_seq_len, + 'max_new_tokens': args.max_new_tokens + }, + 'evaluation_results': metrics, + 'timing': { + 'total_inference_time': inference_time, + 'avg_inference_time': + inference_time / len(results) if results else 0, + 'evaluation_timestamp': datetime.now().isoformat() + } + } + + # Save summary + summary_file = os.path.join(output_dir, "summary.json") + with open(summary_file, 'w', encoding='utf-8') as f: + json.dump(summary, f, indent=2, ensure_ascii=False) + + logger.info(f"Results saved to {output_dir}") + + return metrics + + +def main(): + """Main evaluation function.""" + args = parse_arguments() + logger.set_level(args.log_level) + + # Setup experiment name + if not args.exp_name: + timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') + model_name = os.path.basename(args.model_path).replace('/', '_') + args.exp_name = f"longbench_v2_{model_name}_{timestamp}" + + output_dir = os.path.join(args.output_dir, args.exp_name) + + logger.info( + "=========== LongBench v2 Evaluation with TensorRT-LLM ===========") + + os.makedirs(output_dir, exist_ok=True) + + # Save configuration + config_file = os.path.join(output_dir, "config.json") + with open(config_file, 'w') as f: + json.dump(vars(args), f, indent=2) + logger.info(f"Configuration saved to {config_file}") + + # Initialize LLM and tokenizer + llm, tokenizer = initialize_llm(args) + + # Run evaluation + logger.info(f"Starting LongBench v2 evaluation...") + results, inference_time = evaluate_longbench_v2(llm, tokenizer, args) + + # Save results and get metrics + metrics = save_results(results, args, inference_time, output_dir) + + logger.info(f"{'-'*80}") + logger.info(f"Evaluation completed successfully!") + logger.info(f"Total time: {inference_time:.2f} seconds") + logger.info(f"Total samples: {len(results)}") + + if metrics: + logger.info( + f"Overall accuracy: {metrics.get('overall_accuracy', 'N/A')}%") + + if 'easy_accuracy' in metrics: + logger.info( + f"Easy difficulty accuracy: {metrics['easy_accuracy']}% ({metrics.get('easy_samples', 0)} samples)" + ) + if 'hard_accuracy' in metrics: + logger.info( + f"Hard difficulty accuracy: {metrics['hard_accuracy']}% ({metrics.get('hard_samples', 0)} samples)" + ) + + for length in ['short', 'medium', 'long']: + if f'{length}_accuracy' in metrics: + logger.info( + f"{length.capitalize()} length accuracy: {metrics[f'{length}_accuracy']}% ({metrics.get(f'{length}_samples', 0)} samples)" + ) + + +if __name__ == '__main__': + main() diff --git a/examples/longbench/requirements.txt b/examples/longbench/requirements.txt new file mode 100644 index 0000000000..fcc5fa8634 --- /dev/null +++ b/examples/longbench/requirements.txt @@ -0,0 +1,3 @@ +jieba +fuzzywuzzy +rouge diff --git a/tensorrt_llm/_torch/attention_backend/__init__.py b/tensorrt_llm/_torch/attention_backend/__init__.py index 68805efa4a..2d4574230b 100644 --- a/tensorrt_llm/_torch/attention_backend/__init__.py +++ b/tensorrt_llm/_torch/attention_backend/__init__.py @@ -1,5 +1,6 @@ from 
..flashinfer_utils import IS_FLASHINFER_AVAILABLE from .interface import AttentionBackend, AttentionMetadata +from .sparse import get_sparse_attn_kv_cache_manager from .trtllm import AttentionInputType, TrtllmAttention, TrtllmAttentionMetadata from .vanilla import VanillaAttention, VanillaAttentionMetadata @@ -11,6 +12,7 @@ __all__ = [ "TrtllmAttentionMetadata", "VanillaAttention", "VanillaAttentionMetadata", + "get_sparse_attn_kv_cache_manager", ] if IS_FLASHINFER_AVAILABLE: diff --git a/tensorrt_llm/_torch/attention_backend/interface.py b/tensorrt_llm/_torch/attention_backend/interface.py index cdbe7c8c97..3104ab005c 100644 --- a/tensorrt_llm/_torch/attention_backend/interface.py +++ b/tensorrt_llm/_torch/attention_backend/interface.py @@ -140,6 +140,7 @@ class AttentionMetadata: _saved_tensors: Dict[str, torch.Tensor] = field(init=False, default_factory=dict) + sparse_attention_config: Optional["SparseAttentionConfig"] = None def __post_init__(self) -> None: if self.is_cross: @@ -563,6 +564,7 @@ class AttentionBackend(Generic[TMetadata]): num_kv_heads: Optional[int] = None, quant_config: Optional[QuantConfig] = None, skip_create_weights_in_init: bool = False, + sparse_attention_config: Optional["SparseAttentionConfig"] = None, **kwargs, ): """ @@ -573,12 +575,14 @@ class AttentionBackend(Generic[TMetadata]): head_dim (int): The size of each attention head (hidden_size // num_heads). num_kv_heads (int): The number of kv heads. Defaults to num_heads if None. quant_config (QuantConfig): Optional quantization configuration. If None, no quantization is applied. + sparse_attention_config (SparseAttentionConfig): Optional sparse attention configuration. If None, no sparse attention is applied. """ self.layer_idx = layer_idx self.num_heads = num_heads self.head_dim = head_dim self.num_kv_heads = num_kv_heads or self.num_heads self.quant_config = quant_config + self.sparse_attention_config = sparse_attention_config def update_quant_config(self, new_quant_config: Optional[QuantConfig]): """ diff --git a/tensorrt_llm/_torch/attention_backend/sparse/__init__.py b/tensorrt_llm/_torch/attention_backend/sparse/__init__.py new file mode 100644 index 0000000000..f293f95475 --- /dev/null +++ b/tensorrt_llm/_torch/attention_backend/sparse/__init__.py @@ -0,0 +1,11 @@ +from .utils import (get_flashinfer_sparse_attn_attention_backend, + get_sparse_attn_kv_cache_manager, + get_trtllm_sparse_attn_attention_backend, + get_vanilla_sparse_attn_attention_backend) + +__all__ = [ + "get_sparse_attn_kv_cache_manager", + "get_vanilla_sparse_attn_attention_backend", + "get_trtllm_sparse_attn_attention_backend", + "get_flashinfer_sparse_attn_attention_backend", +] diff --git a/tensorrt_llm/_torch/attention_backend/sparse/kernel.py b/tensorrt_llm/_torch/attention_backend/sparse/kernel.py new file mode 100644 index 0000000000..695e44a592 --- /dev/null +++ b/tensorrt_llm/_torch/attention_backend/sparse/kernel.py @@ -0,0 +1,308 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def _index_gather_kernel(output_ptr, input_ptr, index_ptr, in_row_stride, + in_token_stride, in_head_stride, idx_row_stride, + idx_token_stride, idx_head_stride, dim_size, + BLOCK_SIZE: tl.constexpr): + # get program id and block offset + row_pid = tl.program_id(0) + token_pid = tl.program_id(1) + head_pid = tl.program_id(2) + token_block_num = tl.num_programs(1) + head_num = tl.num_programs(2) + + # get index + indices_idx = row_pid * idx_row_stride + token_pid * idx_token_stride + head_pid * idx_head_stride + token_idx 
= tl.load(index_ptr + indices_idx) + + # get input and output base address + input_base = (row_pid * in_row_stride + token_idx * in_token_stride + + head_pid * in_head_stride) + output_base = (row_pid * token_block_num * head_num * dim_size + + token_pid * head_num * dim_size + head_pid * dim_size) + + # process elements in blocks + for dim_offset in tl.range(0, dim_size, BLOCK_SIZE): + # get offsets + offsets = tl.arange(0, BLOCK_SIZE) + dim_indices = dim_offset + offsets + mask = dim_indices < dim_size + + # load input and store output + input_val = tl.load(input_ptr + input_base + dim_indices, + mask=mask, + other=0.0) + tl.store(output_ptr + output_base + dim_indices, input_val, mask=mask) + + +def triton_index_gather(input, indices): + assert input.ndim == 4, "Input must be a 4D tensor, [row, token, head, dim]" + assert indices.ndim == 3, "Indices must be a 3D tensor, [row, token, head]" + + # shape of input and indices + row_size = input.shape[0] + head_num = input.shape[2] + dim_size = input.shape[3] + num_tokens = indices.shape[1] + + # create output tensor + output = torch.empty((row_size, num_tokens, head_num, dim_size), + device='cuda', + dtype=input.dtype) + + # launch kernel + grid = (row_size, num_tokens, head_num) + _index_gather_kernel[grid](output, + input, + indices, + input.stride(0), + input.stride(1), + input.stride(2), + indices.stride(0), + indices.stride(1), + indices.stride(2), + dim_size, + BLOCK_SIZE=1024) + return output + + +@triton.jit +def _update_kt_cache_ctx_kernel(k_ptr, cache_ptr, block_offsets_ptr, + cum_seq_lens_ptr, cum_kt_seq_lens_ptr, + token_to_batch_map_ptr, num_kv_heads, dim_size, + kt_page_size, tokens_per_block, + max_kt_blocks_per_seq, + BLOCK_SIZE: tl.constexpr): + # get program id + kt_token_idx = tl.program_id(0) + + # get params + batch_idx = tl.load(token_to_batch_map_ptr + kt_token_idx) + kv_start_idx = tl.load(cum_seq_lens_ptr + batch_idx) + kv_end_idx = tl.load(cum_seq_lens_ptr + batch_idx + 1) + kt_start_idx = tl.load(cum_kt_seq_lens_ptr + batch_idx) + local_kt_token_idx = kt_token_idx - kt_start_idx + global_kv_token_idx = kv_start_idx + local_kt_token_idx * kt_page_size + + # get offsets + hidden_size = num_kv_heads * dim_size + k_base = k_ptr + global_kv_token_idx * hidden_size + block_offset = batch_idx * max_kt_blocks_per_seq + local_kt_token_idx // tokens_per_block + block_idx = tl.load(block_offsets_ptr + block_offset) + token_idx_in_block = local_kt_token_idx % tokens_per_block + cache_base = cache_ptr + (block_idx * tokens_per_block + + token_idx_in_block) * hidden_size * 2 + + # compute min/max and store kt + for hidden_start in tl.range(0, hidden_size, BLOCK_SIZE): + hidden_indices = hidden_start + tl.arange(0, BLOCK_SIZE) + head_idx = hidden_indices // dim_size + dim_idx = hidden_indices % dim_size + dim_mask = hidden_indices < hidden_size + + # get k_min and k_max + k_min = tl.full((BLOCK_SIZE, ), float('inf'), dtype=tl.float32) + k_max = tl.full((BLOCK_SIZE, ), float('-inf'), dtype=tl.float32) + for i in range(kt_page_size): + if global_kv_token_idx + i < kv_end_idx: + k = tl.load(k_base + i * hidden_size + hidden_indices, + mask=dim_mask, + other=0.0) + k_min = tl.minimum(k_min, k) + k_max = tl.maximum(k_max, k) + k_min = k_min.to(cache_ptr.dtype.element_ty) + k_max = k_max.to(cache_ptr.dtype.element_ty) + + # store k_min and k_max to cache + k_min_offset = cache_base + head_idx * dim_size * 2 + dim_idx + k_max_offset = k_min_offset + dim_size + tl.store(k_min_offset, k_min, mask=dim_mask) + tl.store(k_max_offset, k_max, 
mask=dim_mask) + + +@triton.jit +def _update_kt_cache_gen_kernel(k_ptr, cache_ptr, block_offsets_ptr, + seq_lens_ptr, num_kv_heads, dim_size, + kt_page_size, tokens_per_block, + max_kt_blocks_per_seq, + BLOCK_SIZE: tl.constexpr): + # get program id + batch_idx = tl.program_id(0) + head_idx = tl.program_id(1) + + # get params + past_key_value_length = tl.load(seq_lens_ptr + batch_idx) - 1 + kt_token_idx = past_key_value_length // kt_page_size + kt_token_idx_in_page = past_key_value_length % kt_page_size + + # get offsets + hidden_size = num_kv_heads * dim_size + k_base = k_ptr + batch_idx * hidden_size + head_idx * dim_size + block_offset = batch_idx * max_kt_blocks_per_seq + kt_token_idx // tokens_per_block + block_idx = tl.load(block_offsets_ptr + block_offset) + kt_token_idx_in_block = kt_token_idx % tokens_per_block + cache_base = cache_ptr + (block_idx * tokens_per_block + + kt_token_idx_in_block) * hidden_size * 2 + cache_base += head_idx * dim_size * 2 + + # update kt cache + for hidden_start in tl.range(0, dim_size, BLOCK_SIZE): + hidden_indices = hidden_start + tl.arange(0, BLOCK_SIZE) + dim_mask = hidden_indices < dim_size + + # load k + k = tl.load(k_base + hidden_indices, mask=dim_mask, other=0.0) + + # load kt cache + kt_mask = dim_mask & (kt_token_idx_in_page > 0) + k_min = tl.load(cache_base + hidden_indices, + mask=kt_mask, + other=float('inf')) + k_max = tl.load(cache_base + hidden_indices + dim_size, + mask=kt_mask, + other=float('-inf')) + k_min = tl.minimum(k_min, k) + k_max = tl.maximum(k_max, k) + k_min = k_min.to(cache_ptr.dtype.element_ty) + k_max = k_max.to(cache_ptr.dtype.element_ty) + + # store k_min and k_max to cache + tl.store(cache_base + hidden_indices, k_min, mask=dim_mask) + tl.store(cache_base + hidden_indices + dim_size, k_max, mask=dim_mask) + + +@triton.jit +def _load_kt_cache_kernel(kt_states_ptr, cache_ptr, block_offsets_ptr, + cum_kt_seq_lens_ptr, token_to_batch_map_ptr, + num_kv_heads, dim_size, tokens_per_block, + max_kt_blocks_per_seq, BLOCK_SIZE: tl.constexpr): + # get program id + kt_token_idx = tl.program_id(0) + + # get params + batch_idx = tl.load(token_to_batch_map_ptr + kt_token_idx) + kt_start_idx = tl.load(cum_kt_seq_lens_ptr + batch_idx) + local_kt_token_idx = kt_token_idx - kt_start_idx + + # get offsets + hidden_size = num_kv_heads * dim_size * 2 + kt_states_base = kt_states_ptr + kt_token_idx * hidden_size + block_offset = batch_idx * max_kt_blocks_per_seq + local_kt_token_idx // tokens_per_block + block_idx = tl.load(block_offsets_ptr + block_offset) + token_idx_in_block = local_kt_token_idx % tokens_per_block + cache_base = cache_ptr + (block_idx * tokens_per_block + + token_idx_in_block) * hidden_size + + # load kt cache + for hidden_start in tl.range(0, hidden_size, BLOCK_SIZE): + hidden_indices = hidden_start + tl.arange(0, BLOCK_SIZE) + mask = hidden_indices < hidden_size + # load kt cache + kt = tl.load(cache_base + hidden_indices, mask=mask, other=0.0) + # store kt to kt_states + tl.store(kt_states_base + hidden_indices, kt, mask=mask) + + +def triton_update_kt_cache(k, + kt_cache_tensor, + kt_cache_block_offsets, + seq_lens, + kt_page_size, + tokens_per_block, + max_kt_blocks_per_seq, + update=True): + # inputs: + # k: (total_seq_len, num_kv_heads, head_dim) + # kt_cache_tensor: (num_blocks, tokens_per_block, num_kv_heads, 2 * head_dim) + # kt_cache_block_offsets: (max_batch_size, max_kt_blocks_per_seq) + # seq_lens: (batch_size) + # kt_page_size: int + # update: bool + + # outputs: + # kt_states: (total_kt_tokens, 
num_kv_heads, 2 * head_dim) + + # params + batch_size = seq_lens.size(0) + num_kv_heads = k.size(1) + head_dim = k.size(2) + tokens_per_block = kt_cache_tensor.size(1) + num_kt_tokens = (seq_lens + kt_page_size - 1) // kt_page_size + + # context + if not update: + total_num_kt_tokens = num_kt_tokens.sum().item() + cum_seq_lens = torch.cumsum(torch.cat([ + torch.zeros(1, device='cuda', dtype=torch.long), + seq_lens.to(torch.long) + ]), + dim=0) + cum_kt_seq_lens = torch.cumsum(torch.cat([ + torch.zeros(1, device='cuda', dtype=torch.long), + num_kt_tokens.to(torch.long) + ]), + dim=0) + + token_to_batch_map = torch.repeat_interleave( + torch.arange(batch_size, + device='cuda'), repeats=num_kt_tokens).to(torch.long) + grid = (total_num_kt_tokens, ) + _update_kt_cache_ctx_kernel[grid](k, + kt_cache_tensor, + kt_cache_block_offsets, + cum_seq_lens, + cum_kt_seq_lens, + token_to_batch_map, + num_kv_heads, + head_dim, + kt_page_size, + tokens_per_block, + max_kt_blocks_per_seq, + BLOCK_SIZE=1024) + return + else: + # generation + # update kt cache + grid = (batch_size, num_kv_heads) + _update_kt_cache_gen_kernel[grid](k, + kt_cache_tensor, + kt_cache_block_offsets, + seq_lens, + num_kv_heads, + head_dim, + kt_page_size, + tokens_per_block, + max_kt_blocks_per_seq, + BLOCK_SIZE=1024) + + # load kt cache + total_num_kt_tokens = num_kt_tokens.sum().item() + kt_states = torch.empty( + (total_num_kt_tokens, num_kv_heads, 2 * head_dim), + device='cuda', + dtype=k.dtype) + token_to_batch_map = torch.repeat_interleave( + torch.arange(batch_size, + device='cuda'), repeats=num_kt_tokens).to(torch.long) + cum_kt_seq_lens = torch.cumsum(torch.cat([ + torch.zeros(1, device='cuda', dtype=torch.long), + num_kt_tokens.to(torch.long) + ]), + dim=0) + grid = (total_num_kt_tokens, ) + _load_kt_cache_kernel[grid](kt_states, + kt_cache_tensor, + kt_cache_block_offsets, + cum_kt_seq_lens, + token_to_batch_map, + num_kv_heads, + head_dim, + tokens_per_block, + max_kt_blocks_per_seq, + BLOCK_SIZE=1024) + + return kt_states diff --git a/tensorrt_llm/_torch/attention_backend/sparse/rocket.py b/tensorrt_llm/_torch/attention_backend/sparse/rocket.py new file mode 100644 index 0000000000..9c8970ec41 --- /dev/null +++ b/tensorrt_llm/_torch/attention_backend/sparse/rocket.py @@ -0,0 +1,1061 @@ +import math +from collections import deque +from typing import Iterable, List, Optional, Tuple, Union + +import torch +from torch import Tensor +from triton import next_power_of_2 + +import tensorrt_llm +import tensorrt_llm.bindings +from tensorrt_llm._torch.attention_backend.trtllm import ( + TrtllmAttention, TrtllmAttentionMetadata) +from tensorrt_llm._torch.attention_backend.vanilla import ( + VanillaAttention, VanillaAttentionMetadata, repeat_kv) +from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequestState +from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager +from tensorrt_llm._utils import get_size_in_bytes +from tensorrt_llm.bindings import DataType +from tensorrt_llm.bindings.executor import KvCacheConfig +from tensorrt_llm.bindings.internal.batch_manager import \ + CacheType as CacheTypeCpp +from tensorrt_llm.mapping import Mapping +from tensorrt_llm.models.modeling_utils import QuantConfig + +from .kernel import triton_index_gather, triton_update_kt_cache + +ModelConfig = tensorrt_llm.bindings.ModelConfig + + +class RocketTrtllmAttentionMetadata(TrtllmAttentionMetadata): + + def __post_init__(self): + super().__post_init__() + if self.sparse_attention_config is None: + raise ValueError("Sparse 
attention config is not set") + self.prompt_budget = self.sparse_attention_config.prompt_budget + self.kt_cache_block_offsets = torch.empty( + [ + self.max_num_sequences, + self.kv_cache_manager.max_kt_blocks_per_seq + ], + dtype=torch.int32, + device='cuda', + ) + + @property + def kt_tokens_per_block(self) -> Optional[int]: + """ + Returns the number of kt tokens per block from the KV cache manager. + """ + return self.kv_cache_manager.kt_tokens_per_block if self.kv_cache_manager is not None else None + + def prepare(self): + if self.kv_cache_manager is not None: + num_contexts = self.num_contexts + num_generations = self.num_generations + num_requests = num_contexts + num_generations + + for i in range(num_requests): + if i < num_contexts: + self.kv_cache_params.num_cached_tokens_per_seq[i] = 0 + else: + if self.prompt_lens[i] > self.prompt_budget: + self.kv_cache_params.num_cached_tokens_per_seq[ + i] += self.prompt_budget - self.prompt_lens[i] + + super().prepare() + + # Update prompt lens for sparse attention + if self.kv_cache_manager is not None: + _prompt_lens = self.prompt_lens.copy() + for i in range(num_requests): + if i >= num_contexts: + _prompt_lens[i] = min(_prompt_lens[i], self.prompt_budget) + _prompt_lens = torch.tensor(_prompt_lens, + dtype=torch.int, + device='cpu') + self.prompt_lens_cpu[:self.num_seqs].copy_(_prompt_lens) + self.prompt_lens_cuda[:self.num_seqs].copy_( + self.prompt_lens_cpu[:self.num_seqs], non_blocking=True) + self.prompt_lens_cuda_runtime = self.prompt_lens_cuda[:self. + num_seqs] + self.prompt_lens_cpu_runtime = self.prompt_lens_cpu[:self.num_seqs] + + # for kt cache + self.host_kt_cache_block_offsets = self.kv_cache_manager.get_kt_block_offsets( + self.request_ids) + self.kt_cache_block_offsets[:self.num_seqs].copy_( + self.host_kt_cache_block_offsets[:self.num_seqs], + non_blocking=True) + + +@torch.compile(dynamic=True) +def convert_token_to_page_sparse_indices( + sparse_attn_indices: torch.Tensor, sparse_attn_offsets: torch.Tensor, + metadata: 'TrtllmAttentionMetadata' +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Convert token-based sparse attention indices to page-based sparse attention indices. 
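+    Within each (batch, kv_head) pair duplicate page ids are removed; because
+    heads may keep different numbers of unique pages, shorter rows are padded
+    with -1 so every head in a batch spans the same number of page entries.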
+ + Args: + sparse_attn_indices: Token-based indices with shape [num_tokens, num_kv_heads] + sparse_attn_offsets: Offsets with shape [batch_size+1] indicating token boundaries for each batch + metadata: Attention metadata containing tokens_per_block (page_size) + + Returns: + Tuple of (page_indices, page_offsets): + - page_indices: Page-based indices with shape [num_pages, num_kv_heads] + - page_offsets: Updated offsets with shape [batch_size+1] indicating page boundaries for each batch + + Example: + If sparse_attn_indices first dimension is [1, 30, 67] and page_size=32, + the result will be [0, 2] (token 1 -> page 0, token 30 -> page 0, token 67 -> page 2) + """ + page_size = metadata.tokens_per_block + batch_size = sparse_attn_offsets.size(0) - 1 + num_kv_heads = sparse_attn_indices.size(1) + + # Convert token indices to page indices + page_indices = sparse_attn_indices // page_size + + # Process each batch and each kv_head separately to remove duplicates + new_page_indices_list = [] + new_offsets = torch.zeros_like(sparse_attn_offsets) + + current_offset = 0 + for batch_idx in range(batch_size): + start_idx = sparse_attn_offsets[batch_idx] + end_idx = sparse_attn_offsets[batch_idx + 1] + + if start_idx >= end_idx: + # Empty batch + new_offsets[batch_idx + 1] = current_offset + continue + + batch_page_indices = page_indices[ + start_idx:end_idx] # [num_tokens_in_batch, num_kv_heads] + + # For each kv_head, remove duplicates while preserving order + batch_unique_pages = [] + for head_idx in range(num_kv_heads): + head_pages = batch_page_indices[:, head_idx] + unique_pages = torch.unique(head_pages, sorted=False) + batch_unique_pages.append(unique_pages) + + # Find the maximum length among all heads for this batch + max_len = max(pages.size(0) for pages in batch_unique_pages) + + if max_len > 0: + batch_result = torch.full((max_len, num_kv_heads), + fill_value=-1, + dtype=page_indices.dtype, + device=page_indices.device) + + for head_idx in range(num_kv_heads): + unique_pages = batch_unique_pages[head_idx] + batch_result[:unique_pages.size(0), head_idx] = unique_pages + + new_page_indices_list.append(batch_result) + current_offset += max_len + + new_offsets[batch_idx + 1] = current_offset + + new_page_indices = torch.cat(new_page_indices_list, dim=0) + + return new_page_indices, new_offsets + + +class RocketTrtllmAttention(TrtllmAttention): + Metadata = RocketTrtllmAttentionMetadata + + # Access type for different dtype sizes + _access_type = { + 1: torch.int8, + 2: torch.int16, + 4: torch.int32, + 8: torch.int64 + } + + def __init__( + self, + layer_idx: int, + num_heads: int, + head_dim: int, + num_kv_heads: Optional[int] = None, + quant_config: Optional[QuantConfig] = None, + q_scaling: Optional[float] = None, + sparse_attention_config: Optional["SparseAttentionConfig"] = None, + **kwargs): + super().__init__(layer_idx, + num_heads, + head_dim, + sparse_attention_config=sparse_attention_config, + num_kv_heads=num_kv_heads, + quant_config=quant_config, + q_scaling=q_scaling, + **kwargs) + self.topr = sparse_attention_config.topr + self.topk = sparse_attention_config.topk + self.prompt_budget = sparse_attention_config.prompt_budget + self.window_size = sparse_attention_config.window_size + self.kernel_size = sparse_attention_config.kernel_size + self.page_size = sparse_attention_config.page_size + + def sparse_attn_predict( + self, + q: torch.Tensor, + k: Optional[torch.Tensor], + metadata: TrtllmAttentionMetadata, + **kwargs, + ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: 
+ """ + Predict sparse attention indices. + For RocketKV: + - Generation phase: predict RocketKV sparse attention indices + + Returns: + - sparse_attn_indices: [total_selected_indices, num_kv_heads] + - sparse_attn_offsets: [batch_size + 1] with cumulative indices count + """ + if k is None: + q, k, _ = q.split([ + self.num_heads * self.head_dim, self.num_kv_heads * + self.head_dim, self.num_kv_heads * self.head_dim + ], + dim=-1) + + num_contexts = metadata.num_contexts + num_generations = metadata.num_generations + seq_lens = metadata.seq_lens + seq_lens_kv = metadata.seq_lens_kv if metadata.seq_lens_kv is not None else seq_lens + past_seen_tokens = metadata.kv_cache_params.num_cached_tokens_per_seq + + sparse_attn_indices = [] + sparse_attn_offsets = [0] + + q_offset = 0 + k_offset = 0 + + for i in range(num_contexts + num_generations): + seq_len = seq_lens[i].item() + seq_len_kv = seq_lens_kv[i].item() + + if seq_len <= 0 or seq_len_kv <= 0: + assert False, "Invalid sequence length" + + single_q = q[q_offset:q_offset + seq_len] + single_k = k[k_offset:k_offset + seq_len_kv] + + single_q = single_q.view(1, seq_len, self.num_heads, + self.head_dim).transpose(1, 2) + single_k = single_k.view(1, seq_len_kv, self.num_kv_heads, + self.head_dim) + + past_seen_token = past_seen_tokens[i] + # Generation phase: RocketKV sparse attention indices + if i >= num_contexts: + _sparse_attn_indices = self._rocketkv_selection( + single_q, single_k, past_seen_token, metadata, i) + if _sparse_attn_indices is not None: + sparse_attn_indices.append( + _sparse_attn_indices.squeeze(0)) # [topk, num_kv_heads] + sparse_attn_offsets.append(sparse_attn_offsets[-1] + + _sparse_attn_indices.size(1)) + else: + sparse_attn_offsets.append(sparse_attn_offsets[-1]) + + q_offset += seq_len + k_offset += seq_len_kv + + if len(sparse_attn_indices) == 0: + sparse_attn_indices, sparse_attn_offsets = None, None + else: + sparse_attn_indices = torch.cat(sparse_attn_indices, + dim=0).to(torch.int32) + sparse_attn_offsets = torch.tensor(sparse_attn_offsets, + dtype=torch.int32).to(q.device) + sparse_attn_indices, sparse_attn_offsets = convert_token_to_page_sparse_indices( + sparse_attn_indices, sparse_attn_offsets, metadata) + sparse_attn_indices = sparse_attn_indices.transpose(0, + 1).contiguous() + return sparse_attn_indices, sparse_attn_offsets + + def sparse_kv_predict( + self, + q: torch.Tensor, + k: Optional[torch.Tensor], + metadata: TrtllmAttentionMetadata, + **kwargs, + ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: + """ + Predict sparse kv indices. 
+ + For RocketKV: + - Context phase: predict RocketKV sparse kv indices + + Returns: + - flattened_indices: [total_selected_indices, num_kv_heads] + - batch_offsets: [batch_size + 1] with cumulative indices count + """ + if k is None: + q, k, _ = q.split([ + self.num_heads * self.head_dim, self.num_kv_heads * + self.head_dim, self.num_kv_heads * self.head_dim + ], + dim=-1) + + num_contexts = metadata.num_contexts + num_generations = metadata.num_generations + seq_lens = metadata.seq_lens + seq_lens_kv = metadata.seq_lens_kv if metadata.seq_lens_kv is not None else seq_lens + past_seen_tokens = metadata.kv_cache_params.num_cached_tokens_per_seq + + sparse_kv_indices = [] + sparse_kv_offsets = [0] + + q_offset = 0 + k_offset = 0 + + for i in range(num_contexts + num_generations): + seq_len = seq_lens[i].item() + seq_len_kv = seq_lens_kv[i].item() + + if seq_len <= 0 or seq_len_kv <= 0: + assert False, "Invalid sequence length" + + single_q = q[q_offset:q_offset + seq_len] + single_k = k[k_offset:k_offset + seq_len_kv] + + single_q = single_q.view(1, seq_len, self.num_heads, + self.head_dim).transpose(1, 2) + single_k = single_k.view(1, seq_len_kv, self.num_kv_heads, + self.head_dim) + + past_seen_token = past_seen_tokens[i] + if i < num_contexts: + # Context phase: SnapKV sparse kv indices + _sparse_kv_indices = self._get_snapkv_indices( + single_q, single_k, past_seen_token, metadata, i) + if _sparse_kv_indices is not None: + sparse_kv_indices.append( + _sparse_kv_indices.squeeze(0)) # [budget, num_kv_heads] + sparse_kv_offsets.append(sparse_kv_offsets[-1] + + _sparse_kv_indices.size(1)) + else: + sparse_kv_offsets.append(sparse_kv_offsets[-1]) + + q_offset += seq_len + k_offset += seq_len_kv + + if len(sparse_kv_indices) == 0: + sparse_kv_indices, sparse_kv_offsets = None, None + else: + sparse_kv_indices = torch.cat(sparse_kv_indices, + dim=0).to(torch.int32) + sparse_kv_indices = sparse_kv_indices.transpose(0, 1).contiguous() + sparse_kv_offsets = torch.tensor(sparse_kv_offsets, + dtype=torch.int32).to(q.device) + return sparse_kv_indices, sparse_kv_offsets + + def _get_snapkv_indices(self, q: Tensor, k: Tensor, past_seen_token: int, + metadata: RocketTrtllmAttentionMetadata, + sample_idx: int) -> Optional[Tensor]: + """ + Get SnapKV sparse kv indices from the input sequence for context phase. 
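+        Prefix tokens are scored by the softmax attention mass they receive
+        from the last window_size query positions (summed over the observation
+        window and max-pooled with kernel_size); the top prompt_budget -
+        window_size prefix tokens plus the observation window are kept.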
+ The shape of output is (1, prompt_budget, num_kv_heads) + """ + bsz = 1 + seq_len = k.size(1) # k shape: (1, seq_len, num_kv_heads, head_dim) + + # If the sequence length is less than the prompt budget, do not enable sparse kv cache + if seq_len <= self.prompt_budget: + return None + + # Use last window_size tokens as observation + # (1, num_heads, window_size, head_dim) + q_obs = q[:, :, -self.window_size:] + # (1, num_kv_heads, seq_len, head_dim) + k_pre = repeat_kv(k.transpose(1, 2), + self.num_heads // self.num_kv_heads) + + dist = (torch.arange(0, self.window_size, device=q.device)[:, None] - + torch.arange(0, seq_len, device=q.device)[None, :] + seq_len - + self.window_size) + attention_mask = (dist >= 0) + + score = torch.matmul(q_obs, k_pre.transpose(-1, -2)) / math.sqrt( + self.head_dim) + + score = torch.masked_fill( + score, + attention_mask.view(1, 1, self.window_size, seq_len) == False, + torch.scalar_tensor(float("-inf"), + device=score.device, + dtype=score.dtype)) + + score = torch.nn.functional.softmax(score, dim=-1) + + score = torch.masked_fill( + score, + attention_mask.view(1, 1, self.window_size, seq_len) == False, + torch.scalar_tensor(0, device=score.device, dtype=score.dtype)) + + score = score[:, :, -self.window_size:, :-self.window_size].sum(dim=-2) + + score = score.view(bsz, self.num_kv_heads, + self.num_heads // self.num_kv_heads, -1).sum(dim=2) + score = torch.nn.functional.max_pool1d(score, + kernel_size=self.kernel_size, + padding=self.kernel_size // 2, + stride=1) + + # Select top important tokens from prefix + prefix_len = seq_len - self.window_size + selected_prefix_indices = score.topk(self.prompt_budget - + self.window_size, + dim=-1).indices.sort().values + + # Combine selected prefix indices with window indices + window_indices = torch.arange( + prefix_len, seq_len, + device=k.device).unsqueeze(0).unsqueeze(0).expand( + bsz, self.num_kv_heads, -1) + selected_indices = torch.cat([selected_prefix_indices, window_indices], + dim=-1).transpose(1, 2) + + k = k.reshape(1, -1, self.num_kv_heads, self.head_dim) + k_snap = triton_index_gather(k, selected_indices) + # Update KT cache + kt_cache_tensor = metadata.kv_cache_manager.get_kt_buffers( + self.layer_idx) + k_snap_len = torch.clamp( + metadata.kv_lens_cuda_runtime[sample_idx:sample_idx + 1], + max=self.prompt_budget).int() + triton_update_kt_cache( + k_snap.squeeze(0).contiguous(), + kt_cache_tensor, + metadata.kt_cache_block_offsets[sample_idx:sample_idx + 1], + k_snap_len, + self.page_size, + metadata.kt_tokens_per_block, + metadata.kv_cache_manager.max_kt_blocks_per_seq, + update=False) + + return selected_indices + + def _rocketkv_selection(self, q: Tensor, k: Tensor, past_seen_token: int, + metadata: RocketTrtllmAttentionMetadata, + sample_idx: int) -> Tensor: + """ + Implement RocketKV's two-stage selection process for generation phase. 
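+        Stage 1 keeps the top-r largest-magnitude query dimensions and scores
+        each KT page with its per-dimension min/max key summary, picking min or
+        max per dimension according to the query sign. Stage 2 broadcasts the
+        page scores to token granularity and takes the top-k positions per KV
+        head.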
+ The shape of output is (1, topk, num_kv_heads) + """ + bsz = 1 + q_len = q.size(2) + + # Helper functions + def _gather(t: Tensor, dim: int, i: Tensor) -> Tensor: + dim += (dim < 0) * t.ndim + return t.gather( + dim, i.expand(*t.shape[:dim], i.shape[dim], *t.shape[dim + 1:])) + + @torch.compile(disable=not torch.cuda.is_available()) + def _scaled_softmax(x: Tensor, divscale: Tensor | float, + dim: int) -> Tensor: + return torch.softmax(x / divscale, dim=dim) + + # Get KT cache for key-token matching + kt_cache_tensor = metadata.kv_cache_manager.get_kt_buffers( + self.layer_idx) + target_seq_len = past_seen_token + 1 # +1 for current token + + # Update KT cache + kt_states = triton_update_kt_cache( + k.squeeze(0).contiguous(), kt_cache_tensor, + metadata.kt_cache_block_offsets[sample_idx:sample_idx + 1], + metadata.kv_lens_cuda_runtime[sample_idx:sample_idx + 1], + self.page_size, metadata.kt_tokens_per_block, + metadata.kv_cache_manager.max_kt_blocks_per_seq) + kt_states = kt_states.unsqueeze(0).permute(0, 2, 3, 1) + + # Reshape query for multi-head processing + qi = q.view(bsz, self.num_kv_heads, self.num_heads // self.num_kv_heads, + q_len, self.head_dim) + qi_abs = torch.abs(qi) + + # Top-r selection on query features + i1 = torch.topk(qi_abs.mean(dim=2, keepdim=True), self.topr, + dim=-1).indices + qi_hat = _gather(qi, -1, i1) + + # Generate signed indices for key-token matching + i1_sign = torch.where( + qi_hat.sum(dim=2, keepdim=True) > 0, i1 + self.head_dim, + i1).transpose(-1, -2) + + # Gather key tokens and compute attention scores + kt_hat = _gather(kt_states.unsqueeze(2), -2, i1_sign) + qk_hat = qi_hat @ kt_hat + qk_hat = qk_hat.repeat_interleave(self.page_size, + dim=-1)[:, :, :, :, :target_seq_len] + scale = torch.sqrt(self.head_dim * + torch.abs(qi_hat).sum(dim=-1, keepdim=True) / + qi_abs.sum(dim=-1, keepdim=True)) + + # (1, num_kv_heads, num_heads, target_seq_len) + s_hat = _scaled_softmax(qk_hat, scale, dim=-1) + + topk = min(self.topk, target_seq_len) + i2 = torch.topk(s_hat.mean(dim=2, keepdim=True), topk, dim=-1).indices + + iKV = i2[:, :, 0, 0, :].transpose(1, 2) # (1, topk, num_kv_heads) + + return iKV + + +class RocketVanillaAttentionMetadata(VanillaAttentionMetadata): + + def __post_init__(self): + super().__post_init__() + if self.sparse_attention_config is None: + raise ValueError("Sparse attention config is not set") + self.prompt_budget = self.sparse_attention_config.prompt_budget + self.kt_cache_block_offsets = torch.empty( + [ + self.max_num_sequences, + self.kv_cache_manager.max_kt_blocks_per_seq + ], + dtype=torch.int32, + device='cuda', + ) + + def prepare(self) -> None: + super().prepare() + num_contexts = self.num_contexts + num_generations = self.num_generations + num_requests = num_contexts + num_generations + + for i in range(num_requests): + if i < num_contexts: + self.kv_cache_params.num_cached_tokens_per_seq[i] = 0 + else: + if self.prompt_lens[i] > self.prompt_budget: + self.kv_cache_params.num_cached_tokens_per_seq[ + i] += self.prompt_budget - self.prompt_lens[i] + + if self.kv_cache_manager is not None: + # for kt cache + self.host_kt_cache_block_offsets = self.kv_cache_manager.get_kt_block_offsets( + self.request_ids) + self.kt_cache_block_offsets[:self.num_seqs].copy_( + self.host_kt_cache_block_offsets[:self.num_seqs], + non_blocking=True) + + +class RocketVanillaAttention(VanillaAttention): + """ + RocketKV sparse attention implementation. 
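+    Besides the regular paged KV cache, each layer keeps a KT cache of
+    per-page min/max key summaries (allocated by RocketKVCacheManager) that is
+    used to estimate attention scores cheaply during generation.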
+ - Context phase: only support sparse kv cache + - Generation phase: only support sparse attention computation + """ + + Metadata = RocketVanillaAttentionMetadata + + def __init__( + self, + layer_idx: int, + num_heads: int, + head_dim: int, + num_kv_heads: Optional[int] = None, + quant_config: Optional[QuantConfig] = None, + q_scaling: Optional[float] = None, + sparse_attention_config: Optional["SparseAttentionConfig"] = None, + **kwargs): + super().__init__(layer_idx, + num_heads, + head_dim, + sparse_attention_config=sparse_attention_config, + num_kv_heads=num_kv_heads, + quant_config=quant_config, + q_scaling=q_scaling, + **kwargs) + self.topr = sparse_attention_config.topr + self.topk = sparse_attention_config.topk + self.prompt_budget = sparse_attention_config.prompt_budget + self.window_size = sparse_attention_config.window_size + self.kernel_size = sparse_attention_config.kernel_size + self.page_size = sparse_attention_config.page_size + + def _single_request_sparse_kv_predict( + self, q: Optional[Tensor], k: Optional[Tensor], v: Optional[Tensor], + metadata: RocketVanillaAttentionMetadata, past_seen_token: int, + sample_idx: int, **kwargs) -> tuple[Optional[torch.Tensor], int]: + """ + Predict KV indices for writing new key/value pairs. + For RocketKV: + - Context phase: return SnapKV sparse kv indices + - Generation phase: return None + """ + if k is None or v is None: + return None, 0 + + # Generation phase: do not support sparse kv cache + if past_seen_token > 0: + return None, k.size(1) + + # Context phase: predict SnapKV sparse kv indices + sparse_kv_indices = self._get_snapkv_indices(q, k, sample_idx) + + # Gather the key states using the sparse kv indices + if sparse_kv_indices is not None: + k_snap = triton_index_gather(k, sparse_kv_indices) + else: + k_snap = k + + # Update KT cache + kt_cache_tensor = metadata.kv_cache_manager.get_kt_buffers( + self.layer_idx) + target_seq_len = past_seen_token + k_snap.size(1) + kt_cache_position = torch.arange(past_seen_token // self.page_size, + math.ceil(target_seq_len / + self.page_size), + device=q.device) + self._single_request_update_kt_cache( + k_snap, + kt_cache_tensor, + metadata.kt_cache_block_offsets[sample_idx], + target_seq_len, + kt_cache_position, + update=False) + return sparse_kv_indices, k_snap.size(1) + + def _single_request_sparse_attn_predict( + self, q: Tensor, k: Optional[Tensor], v: Optional[Tensor], + kv_cache_tensor: Tensor, metadata: RocketVanillaAttentionMetadata, + past_seen_token: int, sample_idx: int, + **kwargs) -> tuple[Optional[torch.Tensor], int]: + """ + Predict KV cache indices for sparse attention computation. 
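+        Returns an (indices, num_selected) pair: in the context phase indices
+        is None and num_selected is the full sequence length; in the generation
+        phase indices has shape (1, topk, num_kv_heads) and num_selected is the
+        (possibly capped) topk.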
+ For RocketKV: + - Context phase: returns None (use full attention) + - Generation phase: return RocketKV sparse indices for sparse attention computation + """ + if k is None or v is None: + return None, 0 + + # Context phase: use full attention + if past_seen_token == 0: + return None, k.size(1) + + # Get RocketKV sparse indices + sparse_indices = self._rocketkv_selection(q, k, metadata, + past_seen_token, sample_idx) + return sparse_indices, sparse_indices.size(1) + + def _get_snapkv_indices(self, q: Tensor, k: Tensor, + sample_idx: int) -> Tensor: + """Get SnapKV sparse kv indices from the input sequence for context phase.""" + bsz = 1 + seq_len = k.size(1) + + # If the sequence length is less than the prompt budget, do not enable sparse kv cache + if seq_len <= self.prompt_budget: + return None + + # Use last window_size tokens as observation + # (1, num_kv_heads, window_size, head_dim) + q_obs = q[:, :, -self.window_size:] + # (1, num_kv_heads, seq_len, head_dim) + k_pre = repeat_kv(k.transpose(1, 2), self.num_key_value_groups) + + dist = (torch.arange(0, self.window_size, device=q.device)[:, None] - + torch.arange(0, seq_len, device=q.device)[None, :] + seq_len - + self.window_size) + attention_mask = (dist >= 0) + + score = torch.matmul(q_obs, k_pre.transpose(-1, -2)) / math.sqrt( + self.head_dim) + + score = torch.masked_fill( + score, + attention_mask.view(1, 1, self.window_size, seq_len) == False, + torch.scalar_tensor(float("-inf"), + device=score.device, + dtype=score.dtype)) + + score = torch.nn.functional.softmax(score, dim=-1) + + score = torch.masked_fill( + score, + attention_mask.view(1, 1, self.window_size, seq_len) == False, + torch.scalar_tensor(0, device=score.device, dtype=score.dtype)) + + score = score[:, :, -self.window_size:, :-self.window_size].sum(dim=-2) + + score = score.view(bsz, self.num_kv_heads, self.num_key_value_groups, + -1).sum(dim=2) + score = torch.nn.functional.max_pool1d(score, + kernel_size=self.kernel_size, + padding=self.kernel_size // 2, + stride=1) + + # Select top important tokens from prefix + prefix_len = seq_len - self.window_size + selected_prefix_indices = score.topk(self.prompt_budget - + self.window_size, + dim=-1).indices.sort().values + + # Combine selected prefix indices with window indices + window_indices = torch.arange( + prefix_len, seq_len, + device=k.device).unsqueeze(0).unsqueeze(0).expand( + bsz, self.num_kv_heads, -1) + selected_indices = torch.cat([selected_prefix_indices, window_indices], + dim=-1) + + return selected_indices.transpose(1, 2) + + def _single_request_update_kt_cache(self, + k, + kt_cache_tensor, + kt_cache_block_offsets, + seq_len, + cache_position, + update=True): + """Update KT cache for RocketKV algorithm.""" + # (1, num_pages_per_block, num_kv_heads, 2*head_dim) + k_out = kt_cache_tensor[kt_cache_block_offsets[0], :, :, :].unsqueeze(0) + + # k: (1, seq_len, num_kv_heads, head_dim) + if k is not None: + padding_len = self.page_size - ( + (k.size(1) - 1) % self.page_size + 1) + padding_tensor = torch.full( + (1, padding_len, self.num_kv_heads, self.head_dim), + float('inf'), + device=k.device, + dtype=k.dtype) + # (1, seq_len+padding_len, num_kv_heads, head_dim) + k_min = torch.cat([k, padding_tensor], dim=1) + k_min = k_min.reshape(1, -1, self.page_size, self.num_kv_heads, + self.head_dim).amin(dim=2) + k_max = torch.cat([k, -padding_tensor], dim=1) + k_max = k_max.reshape(1, -1, self.page_size, self.num_kv_heads, + self.head_dim).amax(dim=2) + if update and (seq_len - 1) % self.page_size > 0: # gen 
phase + k_min = torch.min(k_min, + k_out[:, cache_position, :, :self.head_dim]) + k_max = torch.max(k_max, k_out[:, cache_position, :, + self.head_dim:]) + k_value = torch.cat([k_min, k_max], dim=-1) + access_type = self._access_type[k_value.dtype.itemsize] + k_out.view(dtype=access_type).index_copy_( + 1, cache_position, k_value.view(dtype=access_type)) + + return k_out[:, :math.ceil(seq_len / self.page_size), :, :] + + def _rocketkv_selection(self, q: Tensor, k: Tensor, + metadata: RocketVanillaAttentionMetadata, + past_seen_token: int, sample_idx: int) -> Tensor: + """Implement RocketKV's two-stage selection process for generation phase.""" + bsz = 1 + q_len = q.size(2) + + # Helper functions + def _gather(t: Tensor, dim: int, i: Tensor) -> Tensor: + dim += (dim < 0) * t.ndim + return t.gather( + dim, i.expand(*t.shape[:dim], i.shape[dim], *t.shape[dim + 1:])) + + @torch.compile(disable=not torch.cuda.is_available()) + def _scaled_softmax(x: Tensor, divscale: Tensor | float, + dim: int) -> Tensor: + return torch.softmax(x / divscale, dim=dim) + + # Get KT cache for key-token matching + kt_cache_tensor = metadata.kv_cache_manager.get_kt_buffers( + self.layer_idx) + target_seq_len = past_seen_token + 1 # +1 for current token + + # Update KT cache + kt_cache_position = torch.arange(past_seen_token // self.page_size, + math.ceil(target_seq_len / + self.page_size), + device=q.device) + kt_states = self._single_request_update_kt_cache( + k, kt_cache_tensor, metadata.kt_cache_block_offsets[sample_idx], + target_seq_len, kt_cache_position) + + # Reshape query for multi-head processing + qi = q.view(bsz, self.num_kv_heads, self.num_heads // self.num_kv_heads, + q_len, self.head_dim) + qi_abs = torch.abs(qi) + + # Top-r selection on query features + i1 = torch.topk(qi_abs.mean(dim=2, keepdim=True), self.topr, + dim=-1).indices + qi_hat = _gather(qi, -1, i1) + + # Generate signed indices for key-token matching + i1_sign = torch.where( + qi_hat.sum(dim=2, keepdim=True) > 0, i1 + self.head_dim, + i1).transpose(-1, -2) + + # Gather key tokens and compute attention scores + kt_hat = _gather( + kt_states.permute(0, 2, 3, 1).unsqueeze(2), -2, i1_sign) + qk_hat = qi_hat @ kt_hat + qk_hat = qk_hat.repeat_interleave(self.page_size, + dim=-1)[:, :, :, :, :target_seq_len] + scale = torch.sqrt(self.head_dim * + torch.abs(qi_hat).sum(dim=-1, keepdim=True) / + qi_abs.sum(dim=-1, keepdim=True)) + + # (1, num_kv_heads, num_heads, target_seq_len) + s_hat = _scaled_softmax(qk_hat, scale, dim=-1) + + topk = min(self.topk, target_seq_len) + i2 = torch.topk(s_hat.mean(dim=2, keepdim=True), topk, dim=-1).indices + + iKV = i2[:, :, 0, 0, :].transpose(1, 2) # (1, topk, num_kv_heads) + + return iKV + + +class RocketKVCacheManager(KVCacheManager): + + def __init__( + self, + kv_cache_config: KvCacheConfig, + kv_cache_type: CacheTypeCpp, + *, + num_layers: int, + num_kv_heads: Union[int, List[Optional[int]]], + head_dim: int, + tokens_per_block: int, + # Note that max_seq_len is not necessarily equal to kv_cache_config.num_tokens. + # It's derived from the model's BuildConfig for consistency with the C++ backend. 
+ max_seq_len: int, + max_batch_size: int, + mapping: Mapping, + dtype: DataType = DataType.HALF, + spec_config: Optional["DecodingBaseConfig"] = None, + layer_mask: Optional[List[bool]] = None, + max_num_tokens: int = 8192, + model_config: Optional[ModelConfig] = None, + max_beam_width: int = 1, + sparse_attn_config: Optional["SparseAttentionConfig"] = None, + **kwargs, + ) -> None: + + assert not kv_cache_config.enable_block_reuse, "RocketKV cache requires block reuse to be disabled in KV cache config" + self.kt_tokens_per_block = next_power_of_2( + math.ceil(tokens_per_block / sparse_attn_config.page_size)) + + super().__init__( + kv_cache_config, + kv_cache_type, + num_layers=num_layers, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + tokens_per_block=tokens_per_block, + max_seq_len=max_seq_len, + max_batch_size=max_batch_size, + mapping=mapping, + dtype=dtype, + spec_config=spec_config, + layer_mask=layer_mask, + max_num_tokens=max_num_tokens, + model_config=model_config, + max_beam_width=max_beam_width, + **kwargs, + ) + self.page_size = sparse_attn_config.page_size + self.prompt_budget = sparse_attn_config.prompt_budget + self.max_batch_size = max_batch_size + + # Per layer KT cache pool + # Use the same number of blocks as the paged kv cache. In this way, the scheduler can use the same number of + # blocks to schedule requests. + # Use kt_tokens_per_block to make sure the KT cache is large enough to hold the kt tokens, + # since kt_tokens_per_block * num_blocks * page_size >= tokens_per_block * num_blocks. + self.num_blocks = self.blocks_in_primary_pool + self.kt_cache_pool_per_layer = [ + torch.empty((self.num_blocks, self.kt_tokens_per_block, + num_kv_heads, head_dim * 2), + device="cuda", + dtype=torch.bfloat16) + for _ in range(self.num_local_layers) + ] + self.base_kt_block_offsets = torch.arange(self.num_blocks, + device="cpu", + dtype=torch.int32) + self.max_kt_blocks_per_seq = self.num_blocks + + # Block manager to manage the KT cache blocks for each request. Different layers share the + # same block ids. 
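+        # paged_kt_block_ids maps request_id -> list of KT block ids; free
+        # blocks are handed out FIFO from a simple deque-based free list.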
+ self.paged_kt_block_ids = dict() + self.free_blocks = deque(range(self.num_blocks)) + + def add_dummy_requests( + self, + request_ids: List[int], + token_nums: Optional[List[int]] = None, + is_gen: bool = False, + prepare_resource: bool = True, + max_num_draft_tokens: int = 0, + use_mrope: bool = False, + max_beam_width: int = 1, + ): + requests = super().add_dummy_requests( + request_ids=request_ids, + token_nums=token_nums, + is_gen=is_gen, + prepare_resource=prepare_resource, + max_num_draft_tokens=max_num_draft_tokens, + use_mrope=use_mrope, + max_beam_width=max_beam_width, + ) + if prepare_resource: + for req in requests: + request_id = req.py_request_id + kt_token_num = math.ceil(req.max_beam_num_tokens / + self.page_size) + self.add_kt_tokens(request_id, kt_token_num) + return requests + + def get_kt_buffers(self, layer_idx: int): + return self.kt_cache_pool_per_layer[layer_idx] + + def get_kt_block_offsets(self, request_ids: List[int]) -> torch.Tensor: + kt_block_offsets = torch.empty( + [len(request_ids), self.max_kt_blocks_per_seq], + device="cpu", + dtype=torch.int32) + for i in range(len(request_ids)): + block_ids = self.paged_kt_block_ids[request_ids[i]] + block_num = len(block_ids) + kt_block_offsets[i, 0:block_num].copy_( + self.base_kt_block_offsets[torch.tensor(block_ids, + dtype=torch.int32, + device="cpu")]) + return kt_block_offsets + + def prepare_resources(self, scheduled_batch): + super().prepare_resources(scheduled_batch) + for req in scheduled_batch.all_requests(): + request_id = req.py_request_id + kt_token_num = math.ceil(req.max_beam_num_tokens / self.page_size) + self.add_kt_tokens(request_id, kt_token_num) + + def update_resources(self, scheduled_batch): + for request in scheduled_batch.context_requests: + if request.state != LlmRequestState.GENERATION_COMPLETE: + seq_len = request.get_num_tokens(0) + rewind_len = max(seq_len - 1 - self.prompt_budget, 0) + self.rewind_kv_cache(request, rewind_len) + self.rewind_kt_cache(request, rewind_len) + + def rewind_kt_cache(self, request, rewind_len): + request_id = request.py_request_id + num_tokens = request.max_beam_num_tokens + updated_kt_token_num = math.ceil( + (num_tokens - rewind_len) / self.page_size) + page_count_needed = self.compute_page_count(updated_kt_token_num, + self.kt_tokens_per_block) + num_rewind_pages = len( + self.paged_kt_block_ids[request_id]) - page_count_needed + if num_rewind_pages > 0: + self._free_kt_pages( + self.paged_kt_block_ids[request_id][-num_rewind_pages:]) + self.paged_kt_block_ids[request_id] = self.paged_kt_block_ids[ + request_id][:-num_rewind_pages] + + def free_resources(self, request): + super().free_resources(request) + request_id = request.py_request_id + self._free_kt_pages(self.paged_kt_block_ids[request_id]) + del self.paged_kt_block_ids[request_id] + + def add_kt_tokens(self, request_id: int, kt_token_num: int): + if kt_token_num > 0: + page_count_needed = self.compute_page_count( + kt_token_num, self.kt_tokens_per_block) + if request_id not in self.paged_kt_block_ids: + self.paged_kt_block_ids[request_id] = [] + if len(self.paged_kt_block_ids[request_id]) < page_count_needed: + new_page = self._allocate_kt_pages( + page_count_needed - + len(self.paged_kt_block_ids[request_id])) + self.paged_kt_block_ids[request_id].extend(new_page) + + def _allocate_kt_pages(self, page_count: int) -> list: + assert len(self.free_blocks) >= page_count, "Not enough pages."
+ pages = [self.free_blocks.popleft() for _ in range(page_count)] + return pages + + def _free_kt_pages(self, page_list: list): + self.free_blocks.extend(page_list) + + def compute_page_count(self, token_count: int, tokens_per_page: int) -> int: + return (token_count + tokens_per_page - 1) // tokens_per_page + + @staticmethod + def get_cache_size_per_token(model_config: ModelConfig, mapping: Mapping, + **kwargs): + # get kv cache dtype bytes + mem_per_token = 2 + quant_config = model_config.quant_config + if quant_config is not None and quant_config.quant_mode.has_fp8_kv_cache( + ): + mem_per_token = 1 + + # get num key value heads + config = model_config.pretrained_config + num_key_value_heads = getattr(config, 'num_key_value_heads', + config.num_attention_heads) + if isinstance(num_key_value_heads, Iterable): + num_key_value_heads = sum(num_key_value_heads) / len( + num_key_value_heads) + + # get head dim + tp_size = 1 if mapping.enable_attention_dp else mapping.tp_size + head_dim = getattr(config, "head_dim", None) + if not isinstance(head_dim, int): + head_dim = config.hidden_size // config.num_attention_heads + head_dim = head_dim * num_key_value_heads // tp_size + + # provide at least 1 layer to prevent division by zero cache size + num_attention_layers = max( + len(mapping.pp_layers(model_config.get_num_attention_layers())), 1) + mem_per_token *= num_attention_layers * head_dim + + # K and V + # 2 for K and V, 2 * kt_tokens_per_block / tokens_per_block for KT cache + tokens_per_block = kwargs['tokens_per_block'] + sparse_attn_config = model_config.sparse_attention_config + kt_tokens_per_block = next_power_of_2( + math.ceil(tokens_per_block / sparse_attn_config.page_size)) + kv_factor = 2 + 2 * kt_tokens_per_block / tokens_per_block + mem_per_token *= kv_factor + return mem_per_token + + def get_cache_bytes_per_token(self): + # 2 for K and V, 2 * kt_tokens_per_block / tokens_per_block for KT cache + kv_factor = self.kv_factor + 2 * self.kt_tokens_per_block / self.tokens_per_block + cache_size_per_token = math.ceil( + kv_factor * sum(self.num_kv_heads_per_layer) * self.head_dim) + + if self.dtype not in (DataType.FP8, DataType.HALF, DataType.BF16, + DataType.FLOAT, DataType.NVFP4): + raise ValueError(f'Cannot support {self.dtype} KV cache.') + + cache_size_bytes_per_token = get_size_in_bytes(cache_size_per_token, + self.dtype) + if self.dtype == DataType.NVFP4: + cache_size_bytes_per_token += self.calculate_scaling_factor_size_bytes( + cache_size_per_token, + quant_vector_size=16, + scaling_factor_dtype=DataType.FP8) + return cache_size_bytes_per_token diff --git a/tensorrt_llm/_torch/attention_backend/sparse/utils.py b/tensorrt_llm/_torch/attention_backend/sparse/utils.py new file mode 100644 index 0000000000..917d42a23f --- /dev/null +++ b/tensorrt_llm/_torch/attention_backend/sparse/utils.py @@ -0,0 +1,39 @@ +from .rocket import (RocketKVCacheManager, RocketTrtllmAttention, + RocketVanillaAttention) + + +def get_sparse_attn_kv_cache_manager( + sparse_attn_config: "SparseAttentionConfig"): + if sparse_attn_config.algorithm == "rocket": + return RocketKVCacheManager + else: + raise ValueError( + f"Unsupported sparse attention algorithm: {sparse_attn_config.algorithm}" + ) + + +def get_vanilla_sparse_attn_attention_backend( + sparse_attn_config: "SparseAttentionConfig"): + if sparse_attn_config.algorithm == "rocket": + return RocketVanillaAttention + else: + raise ValueError( + f"Unsupported sparse attention algorithm in vanilla attention backend: {sparse_attn_config.algorithm}" + ) + 
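As a reading aid (not part of the patch): a small worked example of the per-token memory factor computed by get_cache_size_per_token / get_cache_bytes_per_token above. Here tokens_per_block = 32 is an assumed KV cache block size, page_size = 3 is the RocketSparseAttentionConfig default, and next_power_of_2 is a local stand-in for the helper imported by the cache manager.

```python
import math


def next_power_of_2(n: int) -> int:
    # local stand-in for the helper used by RocketKVCacheManager
    return 1 << max(n - 1, 0).bit_length()


tokens_per_block = 32  # assumed KV cache block size
page_size = 3          # RocketSparseAttentionConfig default

# One KT entry summarizes `page_size` KV tokens; the per-block KT capacity is
# rounded up to a power of two.
kt_tokens_per_block = next_power_of_2(math.ceil(tokens_per_block / page_size))

# 2 accounts for K and V; the KT cache adds 2 * kt_tokens_per_block / tokens_per_block.
kv_factor = 2 + 2 * kt_tokens_per_block / tokens_per_block

print(kt_tokens_per_block, kv_factor)  # 16, 3.0
```

With these example values the KT cache costs roughly half of the K/V bytes again per token, which is what the scheduler-facing kv_factor of 3.0 reflects.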
+ +def get_trtllm_sparse_attn_attention_backend( + sparse_attn_config: "SparseAttentionConfig"): + if sparse_attn_config.algorithm == "rocket": + return RocketTrtllmAttention + else: + raise ValueError( + f"Unsupported sparse attention algorithm in trtllm attention backend: {sparse_attn_config.algorithm}" + ) + + +def get_flashinfer_sparse_attn_attention_backend( + sparse_attn_config: "SparseAttentionConfig"): + raise ValueError( + f"Unsupported sparse attention algorithm in flashinfer attention backend: {sparse_attn_config.algorithm}" + ) diff --git a/tensorrt_llm/_torch/attention_backend/trtllm.py b/tensorrt_llm/_torch/attention_backend/trtllm.py index 90bc6df784..e2e67b7a08 100644 --- a/tensorrt_llm/_torch/attention_backend/trtllm.py +++ b/tensorrt_llm/_torch/attention_backend/trtllm.py @@ -190,6 +190,10 @@ class TrtllmAttentionWrapper: spec_decoding_generation_lengths: Optional[torch.Tensor] = None, attention_sinks: Optional[torch.Tensor] = None, chunked_prefill_buffer_batch_size: int = 1, + sparse_kv_indices: Optional[torch.Tensor] = None, + sparse_kv_offsets: Optional[torch.Tensor] = None, + sparse_attn_indices: Optional[torch.Tensor] = None, + sparse_attn_offsets: Optional[torch.Tensor] = None, **kwargs, ): """ @@ -229,6 +233,10 @@ class TrtllmAttentionWrapper: helix_position_offsets (torch.Tensor): The tensor to store the helix position offsets, with shape (num_tokens) on GPU. attention_sinks (torch.Tensor): The attention sinks (additional value in the denominator of the softmax) with shape of (num_heads_q) on GPU. chunked_prefill_buffer_batch_size (int): used for malloc buffer for k and v in fp8 context mla. the max input kv length is not max_num_tokens in this case. It is chunked_prefill_buffer_batch_size * max_num_tokens. + sparse_kv_indices (torch.Tensor): The sparse indices for the KV cache, with shape of (num_heads_kv, num_sparse_tokens) on GPU. + sparse_kv_offsets (torch.Tensor): The batch offsets for the sparse KV indices, with shape of (num_contexts + 1) on GPU. + sparse_attn_indices (torch.Tensor): The sparse indices for the attention layer, with shape of (num_heads_kv, num_sparse_tokens) on GPU. + sparse_attn_offsets (torch.Tensor): The batch offsets for the sparse attention indices, with shape of (num_generations + 1) on GPU. 
""" self.layer_idx = layer_idx self.tokens_per_block = tokens_per_block @@ -266,7 +274,10 @@ class TrtllmAttentionWrapper: self.softmax_stats_tensor = softmax_stats_tensor self.helix_position_offsets = helix_position_offsets self.attention_sinks = attention_sinks - + self.sparse_kv_indices = sparse_kv_indices + self.sparse_kv_offsets = sparse_kv_offsets + self.sparse_attn_indices = sparse_attn_indices + self.sparse_attn_offsets = sparse_attn_offsets if max_sequence_length > self.rope_params.max_positions: self.rope_params.max_positions = max_sequence_length self.rotary_inv_freq, self.rotary_cos_sin = self.rope_params.create_rope_const_params( @@ -424,6 +435,12 @@ class TrtllmAttentionWrapper: self.spec_decoding_position_offsets, self.spec_decoding_packed_mask ] mla_tensor_params = [self.helix_position_offsets] + sparse_attention_params = [ + self.sparse_kv_indices, + self.sparse_kv_offsets, + self.sparse_attn_indices, + self.sparse_attn_offsets, + ] thop.attention( q, @@ -491,6 +508,7 @@ class TrtllmAttentionWrapper: self.softmax_stats_tensor, spec_decoding_bool_params, spec_decoding_tensor_params, + sparse_attention_params, ) # reset the planned states (especially tensors) to avoid memory leak self.plan() @@ -1239,6 +1257,14 @@ class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]): use_paged_context_fmha=use_paged_context_fmha, is_mla_enable=self.is_mla_enable, ) + + sparse_kv_indices, sparse_kv_offsets, sparse_attn_indices, sparse_attn_offsets = None, None, None, None + if self.sparse_attention_config is not None: + sparse_kv_indices, sparse_kv_offsets = self.sparse_kv_predict( + q, k, metadata) + sparse_attn_indices, sparse_attn_offsets = self.sparse_attn_predict( + q, k, metadata) + self.wrapper.plan( layer_idx=self.get_local_layer_idx(metadata), tokens_per_block=metadata.tokens_per_block, @@ -1287,6 +1313,10 @@ class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]): spec_decoding_generation_lengths, attention_sinks=attention_sinks, chunked_prefill_buffer_batch_size=chunked_prefill_buffer_batch_size, + sparse_kv_indices=sparse_kv_indices, + sparse_kv_offsets=sparse_kv_offsets, + sparse_attn_indices=sparse_attn_indices, + sparse_attn_offsets=sparse_attn_offsets, ) out_dtype = None if out_scale is not None: @@ -1500,3 +1530,27 @@ class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]): self.num_heads, self.mla_params.v_head_dim, ) + + def sparse_attn_predict( + self, + q: torch.Tensor, + k: Optional[torch.Tensor], + metadata: TrtllmAttentionMetadata, + **kwargs, + ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: + """ + Predict sparse attn indices. It's implemented in the derived class. + """ + raise NotImplementedError + + def sparse_kv_predict( + self, + q: torch.Tensor, + k: Optional[torch.Tensor], + metadata: TrtllmAttentionMetadata, + **kwargs, + ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: + """ + Predict sparse kv indices. It's implemented in the derived class. 
+ """ + raise NotImplementedError diff --git a/tensorrt_llm/_torch/attention_backend/utils.py b/tensorrt_llm/_torch/attention_backend/utils.py index b741ec37c7..8879541675 100644 --- a/tensorrt_llm/_torch/attention_backend/utils.py +++ b/tensorrt_llm/_torch/attention_backend/utils.py @@ -3,22 +3,33 @@ from typing import Optional, Type from ...models.modeling_utils import QuantConfig from ..flashinfer_utils import IS_FLASHINFER_AVAILABLE from .interface import AttentionBackend, MLAParams, PositionalEmbeddingParams +from .sparse import (get_flashinfer_sparse_attn_attention_backend, + get_trtllm_sparse_attn_attention_backend, + get_vanilla_sparse_attn_attention_backend) from .trtllm import TrtllmAttention from .vanilla import VanillaAttention -def get_attention_backend(backend_name: str) -> Type[AttentionBackend]: +def get_attention_backend( + backend_name: str, + sparse_attn_config: Optional["SparseAttentionConfig"] = None +) -> Type[AttentionBackend]: if backend_name == "VANILLA": + if sparse_attn_config is not None: + return get_vanilla_sparse_attn_attention_backend(sparse_attn_config) return VanillaAttention elif backend_name == "TRTLLM": + if sparse_attn_config is not None: + return get_trtllm_sparse_attn_attention_backend(sparse_attn_config) return TrtllmAttention elif backend_name == "FLASHINFER" and IS_FLASHINFER_AVAILABLE: from .flashinfer import FlashInferAttention - + if sparse_attn_config is not None: + return get_flashinfer_sparse_attn_attention_backend( + sparse_attn_config) return FlashInferAttention elif backend_name == "FLASHINFER_STAR_ATTENTION" and IS_FLASHINFER_AVAILABLE: from .star_flashinfer import StarAttention - return StarAttention return TrtllmAttention @@ -42,12 +53,13 @@ def create_attention( predicted_tokens_per_seq: Optional[int] = 1, skip_create_weights_in_init: bool = False, attention_chunk_size: Optional[int] = None, + sparse_attention_config: Optional["SparseAttentionConfig"] = None, ): if attention_chunk_size is not None and backend_name.upper() != "TRTLLM": raise ValueError( f"Backend {backend_name} does not support chunked attention.") - attn_cls = get_attention_backend(backend_name) + attn_cls = get_attention_backend(backend_name, sparse_attention_config) if is_mla_enable: assert attn_cls.support_mla( @@ -76,4 +88,5 @@ def create_attention( mla_params=mla_params, skip_create_weights_in_init=skip_create_weights_in_init, attention_chunk_size=attention_chunk_size, + sparse_attention_config=sparse_attention_config, ) diff --git a/tensorrt_llm/_torch/attention_backend/vanilla.py b/tensorrt_llm/_torch/attention_backend/vanilla.py index 125527455a..46c765e23c 100644 --- a/tensorrt_llm/_torch/attention_backend/vanilla.py +++ b/tensorrt_llm/_torch/attention_backend/vanilla.py @@ -13,6 +13,7 @@ except ImportError: from .interface import (AttentionBackend, AttentionMask, AttentionMetadata, PredefinedAttentionMask) +from .sparse.kernel import triton_index_gather def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: @@ -94,90 +95,157 @@ class VanillaAttention(AttentionBackend[VanillaAttentionMetadata]): self.num_key_value_groups = self.num_heads // self.num_kv_heads self.q_scaling = q_scaling - def _single_request_update_kv_cache(self, k, v, kv_cache_tensor, seq_len, - cache_idx, cache_position): + def _single_request_sparse_attn_predict( + self, q: torch.Tensor, k: Optional[torch.Tensor], + v: Optional[torch.Tensor], kv_cache_tensor: torch.Tensor, + metadata: AttentionMetadata, past_seen_token: int, sample_idx: int, + **kwargs) -> 
tuple[Optional[torch.Tensor], int]: + raise NotImplementedError + + def _single_request_sparse_kv_predict( + self, q: Optional[torch.Tensor], k: Optional[torch.Tensor], + v: Optional[torch.Tensor], metadata: AttentionMetadata, + past_seen_token: int, sample_idx: int, + **kwargs) -> tuple[Optional[torch.Tensor], int]: + raise NotImplementedError + + def _single_request_update_kv_cache(self, + k, + v, + kv_cache_tensor, + past_seen_token, + kv_len, + cache_idx, + sparse_kv_indices=None): + # select tokens using the sparse kv indices + if sparse_kv_indices is not None: + k_selected = triton_index_gather(k, sparse_kv_indices) + v_selected = triton_index_gather(v, sparse_kv_indices) + else: + k_selected, v_selected = k, v + + # get cache position + seq_len = past_seen_token + kv_len + cache_position = torch.arange(past_seen_token, + seq_len, + device=kv_cache_tensor.device) + + # get kv cache tensor k_out = kv_cache_tensor[cache_idx, 0, :, :, :].unsqueeze(0) v_out = kv_cache_tensor[cache_idx, 1, :, :, :].unsqueeze(0) + # update kv cache if k is not None and v is not None: - access_type = self._access_type[k.dtype.itemsize] - k_out.view(dtype=access_type).index_copy_(1, cache_position, - k.view(dtype=access_type)) - v_out.view(dtype=access_type).index_copy_(1, cache_position, - v.view(dtype=access_type)) + access_type = self._access_type[k_selected.dtype.itemsize] + k_out.view(dtype=access_type).index_copy_( + 1, cache_position, k_selected.view(dtype=access_type)) + v_out.view(dtype=access_type).index_copy_( + 1, cache_position, v_selected.view(dtype=access_type)) - return k_out[:, :seq_len, :, :], v_out[:, :seq_len, :, :] - - def _single_request_forward(self, - q, - k, - v, - attention_mask: AttentionMask, - kv_cache_tensor, - past_seen_token, - cache_idx, - attention_window_size: Optional[int] = None): + # return past kv and the dense kv tensors for sparse attention + if sparse_kv_indices is not None: + k_states = torch.cat([k_out[:, :past_seen_token, :, :], k], dim=1) + v_states = torch.cat([v_out[:, :past_seen_token, :, :], v], dim=1) + else: + k_states, v_states = k_out[:, :seq_len, :, :], v_out[:, : + seq_len, :, :] + return k_states, v_states + def _single_request_preprocess_inputs(self, q, k, v, kv_dtype): bsz = 1 q_len = q.size(0) - # Query q = q.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - # Key and Value - target_seq_len = past_seen_token + kv_len = 0 if k is not None and v is not None: kv_len = k.size(0) k = k.view(bsz, kv_len, self.num_kv_heads, self.head_dim) v = v.view(bsz, kv_len, self.num_kv_heads, self.head_dim) - target_seq_len += kv_len if self.quant_config and self.quant_config.layer_quant_mode.has_any_quant( ): qc = self.quant_config if qc.layer_quant_mode.has_fp8_kv_cache(): - assert kv_cache_tensor.dtype == torch.float8_e4m3fn, f"KV cache should have fp8 dtype, but get {kv_cache_tensor.dtype}" + assert kv_dtype == torch.float8_e4m3fn, \ + f"KV cache should have fp8 dtype, but get {kv_dtype}" k = k.to(torch.float8_e4m3fn) v = v.to(torch.float8_e4m3fn) - assert k.dtype == v.dtype == kv_cache_tensor.dtype, f"KV cache dtype {kv_cache_tensor.dtype} does not match k/v dtype {k.dtype}/{v.dtype}" + assert k.dtype == v.dtype == kv_dtype, \ + f"KV cache dtype {kv_dtype} does not match k/v dtype {k.dtype}/{v.dtype}" - cache_position = torch.arange(past_seen_token, - target_seq_len, - device=q.device) + return q, k, v, kv_len - key_states, value_states = self._single_request_update_kv_cache( - k, v, kv_cache_tensor, target_seq_len, cache_idx, cache_position) + 
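The sparse paths above and below rely on triton_index_gather from .sparse.kernel, which this diff does not show. Purely as a reading aid, here is a PyTorch sketch of the gather semantics it is assumed to implement: a per-KV-head selection of token positions, with indices laid out as (num_kv_heads, num_selected) as documented for sparse_kv_indices in the TRTLLM wrapper earlier in this patch. The real Triton kernel may differ in index layout and dtype handling.

```python
import torch


def index_gather_reference(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    """Assumed semantics of triton_index_gather; a reference sketch, not the real kernel.

    x:       (1, seq_len, num_kv_heads, head_dim)
    indices: (num_kv_heads, num_selected) token positions, one row per KV head
    returns: (1, num_selected, num_kv_heads, head_dim)
    """
    num_kv_heads, num_selected = indices.shape
    head_dim = x.size(-1)
    # Reshape indices to (1, num_selected, num_kv_heads, 1) and broadcast over
    # head_dim so torch.gather selects per-head token positions along dim=1.
    idx = indices.long().t().unsqueeze(0).unsqueeze(-1)
    idx = idx.expand(1, num_selected, num_kv_heads, head_dim)
    return torch.gather(x, dim=1, index=idx)
```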
def _single_request_create_attention_mask(self, + attention_mask, + past_seen_token, + kv_len, + q_device, + q_len, + attention_window_size=None): + """ + Create appropriate attention mask based on the attention type. - key_states = key_states.transpose(1, 2).to(q.dtype) - value_states = value_states.transpose(1, 2).to(q.dtype) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - # Attention Mask + Returns: + Tuple of (is_causal, attn_mask) + """ + bsz = 1 is_causal = False attn_mask = None + + # get cache position + seq_len = past_seen_token + kv_len + cache_position = torch.arange(past_seen_token, seq_len, device=q_device) + + # create attention mask if attention_mask == PredefinedAttentionMask.CAUSAL: # Create custom sliding window mask as sdpa doesn't natively support it. if attention_window_size is not None: attn_mask = generate_sliding_window_mask( - bsz, target_seq_len, cache_position, q.device, + bsz, seq_len, cache_position, q_device, attention_window_size) elif past_seen_token == 0: is_causal = True elif q_len != 1: # attn_mask: 4-D tensor (batch_size, 1, query_seq_len, seq_len) - attn_mask = generate_causal_mask(bsz, target_seq_len, - cache_position, q.device) + attn_mask = generate_causal_mask(bsz, seq_len, cache_position, + q_device) elif attention_mask == PredefinedAttentionMask.FULL: pass else: raise ValueError("Unexpected attention mask type") + return attn_mask, is_causal + + def _single_request_attn_forward(self, + q, + key_states, + value_states, + is_causal, + attn_mask, + sparse_indices=None): + """ + Common attention computation using scaled dot-product attention. + """ + # select the key and value states using the sparse indices + if sparse_indices is not None: + key_states = triton_index_gather(key_states, sparse_indices) + value_states = triton_index_gather(value_states, sparse_indices) + + # transpose kv + key_states = key_states.transpose(1, 2).to(q.dtype) + value_states = value_states.transpose(1, 2).to(q.dtype) + + # repeat kv to support MQA/GQA + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + # get qk scale qk_scale = None if self.q_scaling is not None: qk_scale = 1 / (math.sqrt(self.head_dim) * self.q_scaling) + # attention attn_output = torch.nn.functional.scaled_dot_product_attention( q, key_states, @@ -187,21 +255,66 @@ class VanillaAttention(AttentionBackend[VanillaAttentionMetadata]): scale=qk_scale, ) - attn_output = attn_output.squeeze(0) return attn_output - @staticmethod + def _single_request_forward(self, + q, + k, + v, + attention_mask: AttentionMask, + kv_cache_tensor, + past_seen_token, + cache_idx, + sample_idx, + metadata: AttentionMetadata, + attention_window_size: Optional[int] = None, + **kwargs): + # preprocess inputs + q, k, v, kv_len = self._single_request_preprocess_inputs( + q, k, v, kv_cache_tensor.dtype) + + # predict sparse kv indices + sparse_kv_indices = None + if self.sparse_attention_config is not None: + sparse_kv_indices, kv_len = self._single_request_sparse_kv_predict( + q, k, v, metadata, past_seen_token, sample_idx) + + # update kv cache + key_states, value_states = self._single_request_update_kv_cache( + k, v, kv_cache_tensor, past_seen_token, kv_len, cache_idx, + sparse_kv_indices) + + # predict sparse attn indices + sparse_indices = None + if self.sparse_attention_config is not None: + sparse_indices, kv_len = self._single_request_sparse_attn_predict( + 
q, k, v, kv_cache_tensor, metadata, past_seen_token, sample_idx) + + # create attention mask + attn_mask, is_causal = self._single_request_create_attention_mask( + attention_mask, past_seen_token, kv_len, q.device, q.size(2), + attention_window_size) + + # attention + attn_output = self._single_request_attn_forward(q, key_states, + value_states, is_causal, + attn_mask, + sparse_indices) + + return attn_output.squeeze(0) + def no_kv_cache_forward( - q: torch.Tensor, - k: Optional[torch.Tensor], - v: Optional[torch.Tensor], - num_heads: int, - num_kv_heads: int, - metadata: AttentionMetadata, - *, - attention_mask: AttentionMask = PredefinedAttentionMask.CAUSAL, - position_ids: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + self, + q: torch.Tensor, + k: Optional[torch.Tensor], + v: Optional[torch.Tensor], + num_heads: int, + num_kv_heads: int, + metadata: AttentionMetadata, + *, + attention_mask: AttentionMask = PredefinedAttentionMask.CAUSAL, + position_ids: Optional[torch.Tensor] = None, + **kwargs) -> torch.Tensor: """ This function is used to perform attention without kv cache. Args: @@ -275,14 +388,14 @@ class VanillaAttention(AttentionBackend[VanillaAttentionMetadata]): # try to separate the kv cache estimation path from no kv cache attn. num_heads = self.num_heads num_kv_heads = self.num_kv_heads - return VanillaAttention.no_kv_cache_forward( - q=q, - k=k, - v=v, - num_heads=num_heads, - num_kv_heads=num_kv_heads, - metadata=metadata, - attention_mask=attention_mask) + return self.no_kv_cache_forward(q=q, + k=k, + v=v, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + metadata=metadata, + attention_mask=attention_mask, + **kwargs) past_seen_tokens = metadata.kv_cache_params.num_cached_tokens_per_seq cache_indices = [ @@ -298,7 +411,7 @@ class VanillaAttention(AttentionBackend[VanillaAttentionMetadata]): offset = 0 offset_kv = 0 attn_outputs = [] - for i, (seq_len, seq_len_kv) in enumerate( + for sample_idx, (seq_len, seq_len_kv) in enumerate( zip(metadata.seq_lens, metadata.seq_lens_kv)): single_q = q[offset:offset + seq_len] single_k = k[ @@ -306,13 +419,16 @@ class VanillaAttention(AttentionBackend[VanillaAttentionMetadata]): seq_len_kv] if k is not None and seq_len_kv != 0 else None single_v = v[ offset_kv:offset_kv + - seq_len_kv] if k is not None and seq_len_kv != 0 else None - past_seen_token = past_seen_tokens[i] - cache_idx = cache_indices[i] + seq_len_kv] if v is not None and seq_len_kv != 0 else None + + past_seen_token = past_seen_tokens[sample_idx] + cache_idx = cache_indices[sample_idx] attn_output = self._single_request_forward( single_q, single_k, single_v, attention_mask, kv_cache_tensor, - past_seen_token, cache_idx, attention_window_size) + past_seen_token, cache_idx, sample_idx, metadata, + attention_window_size, **kwargs) + attn_outputs.append(attn_output) offset += seq_len diff --git a/tensorrt_llm/_torch/model_config.py b/tensorrt_llm/_torch/model_config.py index 5a225a566b..e1b817d253 100644 --- a/tensorrt_llm/_torch/model_config.py +++ b/tensorrt_llm/_torch/model_config.py @@ -121,6 +121,7 @@ class ModelConfig(Generic[TConfig]): spec_config: Optional["DecodingBaseConfig"] = None lora_config: Optional["LoraConfig"] = None + sparse_attention_config: Optional["SparseAttentionConfig"] = None is_generation: bool = True max_num_tokens: int = 8192 diff --git a/tensorrt_llm/_torch/modules/attention.py b/tensorrt_llm/_torch/modules/attention.py index eb0071e2b5..50710d6d4f 100644 --- a/tensorrt_llm/_torch/modules/attention.py +++ 
b/tensorrt_llm/_torch/modules/attention.py @@ -259,7 +259,9 @@ class Attention(nn.Module): self.quant_config = config.get_quant_config() self.attn_backend = config.attn_backend - attn_cls = get_attention_backend(self.attn_backend) + attn_cls = get_attention_backend( + self.attn_backend, + sparse_attn_config=config.sparse_attention_config) # These two modules are mutually exclusive - either splitted_qkv_lora or fused_qkv_lora will be used, # but never both at the same time. splitted_qkv_lora handles Q,K,V separately while fused_qkv_lora @@ -277,6 +279,9 @@ class Attention(nn.Module): # Whether to fuse RoPE into the attention OP. # If true, RoPE will be applied in self.attn.forward. # If false, RoPE will be applied in self.apply_rope. + if config.sparse_attention_config is not None: + logger.warning("disable rope_fusion for sparse attention.") + rope_fusion = False self.rope_fusion = rope_fusion if self.rope_fusion and not attn_cls.support_fused_rope(): logger.warning( @@ -314,6 +319,7 @@ class Attention(nn.Module): skip_create_weights_in_init=config.skip_create_weights_in_init, q_scaling=self.q_scaling, attention_chunk_size=self.attention_chunk_size, + sparse_attention_config=config.sparse_attention_config, ) self.support_fused_qkv = self.attn.support_fused_qkv() @@ -854,6 +860,7 @@ class MLA(nn.Module): v_head_dim=self.v_head_dim, predicted_tokens_per_seq=self.predicted_tokens_per_seq, skip_create_weights_in_init=config.skip_create_weights_in_init, + sparse_attention_config=config.sparse_attention_config, ) self.mqa = create_attention( @@ -873,6 +880,7 @@ class MLA(nn.Module): v_head_dim=self.kv_lora_rank, predicted_tokens_per_seq=self.predicted_tokens_per_seq, skip_create_weights_in_init=config.skip_create_weights_in_init, + sparse_attention_config=config.sparse_attention_config, ) self.aux_stream = aux_stream diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py index e81f222add..2fa2eeb476 100644 --- a/tensorrt_llm/_torch/pyexecutor/_util.py +++ b/tensorrt_llm/_torch/pyexecutor/_util.py @@ -1,6 +1,5 @@ import os import random -from collections.abc import Iterable from typing import Dict, List, Optional import torch @@ -12,14 +11,15 @@ from tensorrt_llm._utils import str_dtype_to_binding, torch_dtype_to_str from tensorrt_llm.bindings.executor import DecodingMode from tensorrt_llm.llmapi.llm_args import (EagleDecodingConfig, KvCacheConfig, MTPDecodingConfig, PeftCacheConfig, - SamplerType, SpeculativeConfig, - TorchLlmArgs) + SamplerType, SparseAttentionConfig, + SpeculativeConfig, TorchLlmArgs) from tensorrt_llm.logger import logger from tensorrt_llm.lora_helper import (LoraConfig, get_default_trtllm_modules_to_hf_modules) from tensorrt_llm.lora_manager import load_torch_lora from tensorrt_llm.mapping import CpType, Mapping +from ..attention_backend import get_sparse_attn_kv_cache_manager from ..model_config import ModelConfig from ..speculative import get_num_extra_kv_tokens, get_spec_decoder from .config import PyTorchConfig @@ -42,6 +42,20 @@ from .seq_slot_manager import SeqSlotManager GB = 1 << 30 +def get_kv_cache_manager_cls(model_config: ModelConfig): + config = model_config.pretrained_config + sparse_attn_config = model_config.sparse_attention_config + if is_mla(config): + return KVCacheManager + elif is_nemotron_hybrid(config): + return MambaHybridCacheManager + else: + if sparse_attn_config is not None: + return get_sparse_attn_kv_cache_manager(sparse_attn_config) + else: + return KVCacheManager + + class KvCacheCreator: """Groups 
together logic related to KV cache construction.""" @@ -61,6 +75,7 @@ class KvCacheCreator: kv_cache_config: KvCacheConfig, pytorch_backend_config: PyTorchConfig, speculative_config: SpeculativeConfig, + sparse_attention_config: SparseAttentionConfig, ): self._model_engine = model_engine self._draft_model_engine = draft_model_engine @@ -72,50 +87,14 @@ class KvCacheCreator: self._kv_connector_manager = kv_connector_manager self._pytorch_backend_config = pytorch_backend_config self._speculative_config = speculative_config + self._sparse_attention_config = sparse_attention_config self._tokens_per_block = tokens_per_block self._max_seq_len = max_seq_len self._max_batch_size = max_batch_size self._net_max_seq_len = net_max_seq_len self._dummy_reqs = None - - @staticmethod - def _get_cache_size_per_token(model_config: ModelConfig, - mapping: Mapping) -> int: - mem_per_token = 2 - quant_config = model_config.quant_config - if quant_config is not None and quant_config.quant_mode.has_fp8_kv_cache( - ): - mem_per_token = 1 - - config = model_config.pretrained_config - - num_key_value_heads = getattr(config, 'num_key_value_heads', - config.num_attention_heads) - if isinstance(num_key_value_heads, Iterable): - num_key_value_heads = sum(num_key_value_heads) / len( - num_key_value_heads) - - mla = is_mla(config) - tp_size = 1 if mapping.enable_attention_dp else mapping.tp_size - - kv_factor = 2 - if mla: - # MLA has kv_lora_rank and qk_rope_head_dim - head_dim = config.kv_lora_rank + config.qk_rope_head_dim - kv_factor = 1 - else: - _head_dim = getattr(config, 'head_dim', None) - if not isinstance(_head_dim, int): - _head_dim = config.hidden_size // config.num_attention_heads - head_dim = _head_dim * num_key_value_heads // tp_size - - # provide at least 1 layer to prevent division by zero cache size - num_attention_layers = max( - len(mapping.pp_layers(model_config.get_num_attention_layers())), 1) - mem_per_token *= num_attention_layers * head_dim - # K and V - mem_per_token *= kv_factor - return mem_per_token + self._kv_cache_manager_cls = get_kv_cache_manager_cls( + model_engine.model.model_config) def _get_free_gpu_memory_fraction(self) -> float: fraction = self._kv_cache_config.free_gpu_memory_fraction @@ -126,12 +105,14 @@ class KvCacheCreator: def _get_kv_size_per_token(self): model_config = self._model_engine.model.model_config mapping = self._mapping - kv_size_per_token = self._get_cache_size_per_token( - model_config, mapping) + kv_size_per_token = self._kv_cache_manager_cls.get_cache_size_per_token( + model_config, mapping, tokens_per_block=self._tokens_per_block) if self._draft_model_engine is not None: draft_model_config = self._draft_model_engine.model.model_config - kv_size_per_token += self._get_cache_size_per_token( - draft_model_config, mapping) + kv_size_per_token += self._kv_cache_manager_cls.get_cache_size_per_token( + draft_model_config, + mapping, + tokens_per_block=self._tokens_per_block) return kv_size_per_token def _cal_max_memory(self, peak_memory, total_gpu_memory, fraction, @@ -231,6 +212,12 @@ class KvCacheCreator: estimating_kv_cache = True self._kv_cache_config.max_tokens = self._get_token_num_for_estimation( ) + model_config = self._model_engine.model.model_config + if model_config.attn_backend == "VANILLA": + logger.info( + "KV cache size estimation is not supported for Vanilla attention backend, disable it." 
+ ) + estimating_kv_cache = False return estimating_kv_cache def configure_kv_cache_capacity(self, py_executor: PyExecutor) -> None: @@ -374,6 +361,7 @@ class KvCacheCreator: config = model_engine.model.model_config.pretrained_config quant_config = model_engine.model.model_config.quant_config spec_config = self._speculative_config + sparse_attn_config = self._sparse_attention_config hidden_size = config.hidden_size num_attention_heads = config.num_attention_heads @@ -396,7 +384,7 @@ class KvCacheCreator: num_hidden_layers = config.num_hidden_layers if is_mla(config): - kv_cache_manager = KVCacheManager( + kv_cache_manager = self._kv_cache_manager_cls( self._kv_cache_config, tensorrt_llm.bindings.internal.batch_manager.CacheType. SELFKONLY, @@ -434,8 +422,7 @@ class KvCacheCreator: mamba_layer_mask = [ char == "M" for char in config.hybrid_override_pattern ] - - kv_cache_manager = MambaHybridCacheManager( + kv_cache_manager = self._kv_cache_manager_cls( # mamba cache parameters config.ssm_state_size, config.conv_kernel, @@ -518,7 +505,7 @@ class KvCacheCreator: binding_model_config = model_engine.model.model_config.get_bindings_model_config( tokens_per_block=self._tokens_per_block) if is_vswa else None - kv_cache_manager = KVCacheManager( + kv_cache_manager = self._kv_cache_manager_cls( self._kv_cache_config, tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF, num_layers=num_hidden_layers, @@ -536,6 +523,7 @@ class KvCacheCreator: is_draft=model_engine.is_draft_model, kv_connector_manager=self._kv_connector_manager if not estimating_kv_cache else None, + sparse_attn_config=sparse_attn_config, ) # KVCacheManager (Non-draft) modifies the max_seq_len field, update it to self if model_engine.kv_cache_manager_key == ResourceManagerType.KV_CACHE_MANAGER: diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index 6db3116916..53d98fc83d 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -137,6 +137,7 @@ class PyTorchModelEngine(ModelEngine): attn_runtime_features: Optional[AttentionRuntimeFeatures] = None, dist: Optional[MPIDist] = None, spec_config: Optional["DecodingBaseConfig"] = None, + sparse_attention_config: Optional["SparseAttentionConfig"] = None, lora_config: Optional[LoraConfig] = None, is_draft_model: bool = False, drafting_loop_wrapper: Optional[Callable[[torch.nn.Module], @@ -164,6 +165,7 @@ class PyTorchModelEngine(ModelEngine): spec_config.max_draft_len = 0 self.spec_config = spec_config self.is_spec_decode = spec_config is not None + self.sparse_attention_config = sparse_attention_config self.enable_spec_decode = self.is_spec_decode self.is_draft_model = is_draft_model @@ -175,6 +177,7 @@ class PyTorchModelEngine(ModelEngine): pytorch_backend_config=pytorch_backend_config, mapping=self.mapping, spec_config=self.spec_config, + sparse_attention_config=self.sparse_attention_config, max_num_tokens=max_num_tokens, max_seq_len=max_seq_len, lora_config=lora_config, @@ -261,7 +264,8 @@ class PyTorchModelEngine(ModelEngine): self.is_warmup = False self.attn_backend = get_attention_backend( - pytorch_backend_config.attn_backend) + pytorch_backend_config.attn_backend, + sparse_attn_config=sparse_attention_config) if self.is_spec_decode: self.spec_metadata = None @@ -794,7 +798,8 @@ class PyTorchModelEngine(ModelEngine): enable_flash_mla=self.model.model_config.enable_flash_mla, enable_context_mla_with_cached_kv= enable_context_mla_with_cached_kv, - 
cache_indirection=cache_indirection) + cache_indirection=cache_indirection, + sparse_attention_config=self.sparse_attention_config) if self.attn_metadata is not None: # This assertion can be relaxed if needed: just create a new metadata @@ -811,7 +816,8 @@ class PyTorchModelEngine(ModelEngine): runtime_features=self.attn_runtime_features, enable_flash_mla=self.model.model_config.enable_flash_mla, enable_context_mla_with_cached_kv=enable_context_mla_with_cached_kv, - cache_indirection=cache_indirection) + cache_indirection=cache_indirection, + sparse_attention_config=self.sparse_attention_config) return self.attn_metadata diff --git a/tensorrt_llm/_torch/pyexecutor/model_loader.py b/tensorrt_llm/_torch/pyexecutor/model_loader.py index 2909b29cac..18c276e774 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_loader.py +++ b/tensorrt_llm/_torch/pyexecutor/model_loader.py @@ -160,6 +160,7 @@ class ModelLoader: pytorch_backend_config: PyTorchConfig, mapping: Mapping, spec_config: Optional["DecodingBaseConfig"], + sparse_attention_config: Optional["SparseAttentionConfig"], max_num_tokens: int, max_seq_len: Optional[int], lora_config: Optional[LoraConfig] = None): @@ -177,6 +178,7 @@ class ModelLoader: self.pytorch_backend_config = pytorch_backend_config self.mapping = mapping self.spec_config = spec_config + self.sparse_attention_config = sparse_attention_config self.max_num_tokens = max_num_tokens self.max_seq_len = max_seq_len self.lora_config = lora_config @@ -294,6 +296,7 @@ class ModelLoader: force_dynamic_quantization=self.pytorch_backend_config. force_dynamic_quantization, spec_config=self.spec_config, + sparse_attention_config=self.sparse_attention_config, max_num_tokens=self.max_num_tokens, max_seq_len=self.max_seq_len, moe_max_num_tokens=self.pytorch_backend_config.moe_max_num_tokens, diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py index dd6f77f2a6..b3162f1f7d 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py @@ -252,6 +252,8 @@ def create_py_executor( max_num_tokens = 8192 tokens_per_block = kv_cache_config.tokens_per_block + if pytorch_backend_config.attn_backend == "VANILLA": + tokens_per_block = max_num_tokens if pytorch_backend_config.attn_backend in [ "FLASHINFER", "FLASHINFER_STAR_ATTENTION" @@ -308,6 +310,8 @@ def create_py_executor( has_draft_model_engine = spec_config.spec_dec_mode.has_draft_model() has_spec_drafter = spec_config.spec_dec_mode.has_spec_drafter() + sparse_attention_config = llm_args.sparse_attention_config + # chunk_unit_size may be changed to 64 when using flash mla attn_runtime_features = AttentionRuntimeFeatures( chunked_prefill=enable_chunked_context, @@ -331,6 +335,7 @@ def create_py_executor( attn_runtime_features=attn_runtime_features, dist=dist, spec_config=spec_config, + sparse_attention_config=sparse_attention_config, lora_config=lora_config, checkpoint_loader=checkpoint_loader, ) @@ -568,6 +573,7 @@ def create_py_executor( kv_cache_config=kv_cache_config, pytorch_backend_config=pytorch_backend_config, speculative_config=spec_config, + sparse_attention_config=sparse_attention_config, ) estimating_kv_cache = kv_cache_creator.try_prepare_estimation() with mem_monitor.observe_creation_stage( diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index 8e01ae1b95..82249e6748 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ 
b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -3,7 +3,7 @@ import enum import math from abc import ABC, abstractmethod from collections import OrderedDict, defaultdict -from typing import Dict, List, Optional, Set, Tuple, Union +from typing import Dict, Iterable, List, Optional, Set, Tuple, Union import torch @@ -165,6 +165,7 @@ class KVCacheManager(BaseResourceManager): max_beam_width: int = 1, is_draft: bool = False, kv_connector_manager: Optional[KvCacheConnectorManager] = None, + **kwargs, ) -> None: self.mapping = mapping self.dtype = dtype @@ -543,6 +544,64 @@ class KVCacheManager(BaseResourceManager): return get_size_in_bytes(cache_size // quant_vector_size, scaling_factor_dtype) + # TODO: refactor get_cache_size_per_token and get_cache_bytes_per_token to use the same logic + @staticmethod + def get_cache_size_per_token(model_config: ModelConfigPython, + mapping: Mapping, **kwargs): + # get kv cache dtype bytes + mem_per_token = 2 + quant_config = model_config.quant_config + if quant_config is not None and quant_config.quant_mode.has_fp8_kv_cache( + ): + mem_per_token = 1 + + # get num key value heads + config = model_config.pretrained_config + num_key_value_heads = getattr(config, 'num_key_value_heads', + config.num_attention_heads) + if isinstance(num_key_value_heads, Iterable): + num_key_value_heads = sum(num_key_value_heads) / len( + num_key_value_heads) + + # get head dim + mla = hasattr(config, "kv_lora_rank") + if mla: + head_dim = config.kv_lora_rank + config.qk_rope_head_dim + kv_factor = 1 + else: + tp_size = 1 if mapping.enable_attention_dp else mapping.tp_size + head_dim = getattr(config, "head_dim", None) + if not isinstance(head_dim, int): + head_dim = config.hidden_size // config.num_attention_heads + head_dim = head_dim * num_key_value_heads // tp_size + kv_factor = 2 + + # provide at least 1 layer to prevent division by zero cache size + num_attention_layers = max( + len(mapping.pp_layers(model_config.get_num_attention_layers())), 1) + mem_per_token *= num_attention_layers * head_dim + + # K and V + mem_per_token *= kv_factor + return mem_per_token + + def get_cache_bytes_per_token(self): + cache_size_per_token = self.kv_factor * sum( + self.num_kv_heads_per_layer) * self.head_dim + + if self.dtype not in (DataType.FP8, DataType.HALF, DataType.BF16, + DataType.FLOAT, DataType.NVFP4): + raise ValueError(f'Cannot support {self.dtype} KV cache.') + + cache_size_bytes_per_token = get_size_in_bytes(cache_size_per_token, + self.dtype) + if self.dtype == DataType.NVFP4: + cache_size_bytes_per_token += self.calculate_scaling_factor_size_bytes( + cache_size_per_token, + quant_vector_size=16, + scaling_factor_dtype=DataType.FP8) + return cache_size_bytes_per_token + def calculate_max_num_blocks(self, kv_cache_config: KvCacheConfig, head_dim: int, @@ -554,20 +613,7 @@ class KVCacheManager(BaseResourceManager): if kv_cache_config.free_gpu_memory_fraction is not None else 0.9) - cache_size_per_token = kv_factor * sum( - self.num_kv_heads_per_layer) * head_dim - - if dtype not in (DataType.FP8, DataType.HALF, DataType.BF16, - DataType.FLOAT, DataType.NVFP4): - raise ValueError(f'Cannot support {dtype} KV cache.') - - cache_size_bytes_per_token = get_size_in_bytes(cache_size_per_token, - dtype) - if dtype == DataType.NVFP4: - cache_size_bytes_per_token += self.calculate_scaling_factor_size_bytes( - cache_size_per_token, - quant_vector_size=16, - scaling_factor_dtype=DataType.FP8) + cache_size_bytes_per_token = self.get_cache_bytes_per_token() free_mem, total_mem = 
torch.cuda.mem_get_info()
diff --git a/tensorrt_llm/llmapi/__init__.py b/tensorrt_llm/llmapi/__init__.py
index adc0e7e35c..f79a29cac4 100644
--- a/tensorrt_llm/llmapi/__init__.py
+++ b/tensorrt_llm/llmapi/__init__.py
@@ -12,6 +12,7 @@ from .llm_args import (AttentionDpConfig, AutoDecodingConfig, BatchingType,
                        ExtendedRuntimePerfKnobConfig, KvCacheConfig, LlmArgs,
                        LookaheadDecodingConfig, MedusaDecodingConfig,
                        MoeConfig, MTPDecodingConfig, NGramDecodingConfig,
+                       RocketSparseAttentionConfig,
                        SaveHiddenStatesDecodingConfig, SchedulerConfig,
                        TorchCompileConfig, TorchLlmArgs, TrtLlmArgs,
                        UserProvidedDecodingConfig)
@@ -61,4 +62,5 @@ __all__ = [
     'AttentionDpConfig',
     'LoRARequest',
     'SaveHiddenStatesDecodingConfig',
+    'RocketSparseAttentionConfig',
 ]
diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py
index 9ace69ac91..31d38db73e 100644
--- a/tensorrt_llm/llmapi/llm_args.py
+++ b/tensorrt_llm/llmapi/llm_args.py
@@ -166,6 +166,63 @@ class CudaGraphConfig(StrictBaseModel):
         return batch_sizes
 
 
+class BaseSparseAttentionConfig(StrictBaseModel):
+    """
+    Configuration for sparse attention.
+    """
+    algorithm: Literal["rocket"] = Field(
+        default="rocket", description="The algorithm for sparse attention.")
+
+    @classmethod
+    def from_dict(cls, data: dict):
+        # dispatch to the correct sparse attention config
+        config_classes = {
+            "rocket": RocketSparseAttentionConfig,
+        }
+
+        algorithm = data.get("algorithm", None)
+        if algorithm is None:
+            raise ValueError("Sparse attention algorithm is required")
+
+        config_class = config_classes.get(algorithm.lower())
+        if config_class is None:
+            raise ValueError(f"Invalid algorithm: {algorithm}")
+
+        return config_class(**data)
+
+    def _check_fields(self):
+        pass
+
+    def supports_backend(self, backend: str) -> bool:
+        """
+        Override if the sparse attention algorithm does not support
+        a subset of the possible backends.
+        """
+        return True
+
+
+class RocketSparseAttentionConfig(BaseSparseAttentionConfig):
+    """
+    Configuration for RocketKV sparse attention.
+    """
+    window_size: Optional[int] = Field(
+        default=None, description="The window size for SnapKV.")
+    kernel_size: Optional[int] = Field(
+        default=None, description="The kernel size for SnapKV.")
+    topr: Optional[Union[int, float]] = Field(default=76, description="Top-r")
+    topk: Optional[int] = Field(default=128, description="Top-k")
+    prompt_budget: Optional[int] = Field(default=1266,
+                                         description="Prompt budget")
+    page_size: Optional[int] = Field(default=3, description="Page size")
+
+    @classmethod
+    def from_dict(cls, data: dict):
+        return cls(**data)
+
+    def supports_backend(self, backend: str) -> bool:
+        return backend == "pytorch"
+
+
 class MoeConfig(StrictBaseModel):
     """
     Configuration for MoE.
@@ -1133,6 +1190,10 @@ SpeculativeConfig: TypeAlias = Optional[Union[ AutoDecodingConfig, ]] +SparseAttentionConfig: TypeAlias = Union[ + RocketSparseAttentionConfig, +] + @PybindMirror.mirror_pybind_fields(_KvCacheConfig) class KvCacheConfig(StrictBaseModel, PybindMirror): @@ -1510,6 +1571,12 @@ class BaseLlmArgs(StrictBaseModel): description="Cache transceiver config.", status="prototype") + # Sparse attention config + sparse_attention_config: Optional[SparseAttentionConfig] = Field( + default=None, + description="Sparse attention config.", + status="prototype") + # Speculative decoding parameters speculative_config: SpeculativeConfig = Field( default=None, description="Speculative decoding config.") diff --git a/tests/unittest/_torch/attention/sparse/test_rocketkv.py b/tests/unittest/_torch/attention/sparse/test_rocketkv.py new file mode 100644 index 0000000000..34ac891d75 --- /dev/null +++ b/tests/unittest/_torch/attention/sparse/test_rocketkv.py @@ -0,0 +1,80 @@ +import json +import os + +import pytest +from utils.llm_data import llm_models_root + +from tensorrt_llm import LLM, SamplingParams +from tensorrt_llm.llmapi import KvCacheConfig, RocketSparseAttentionConfig + + +@pytest.mark.parametrize("backend", ["pytorch"]) +@pytest.mark.parametrize("model_name", + ["llama-3.1-model/Llama-3.1-8B-Instruct"]) +@pytest.mark.parametrize("attention_backend", ["VANILLA", "TRTLLM"]) +def test_model(backend, model_name, attention_backend): + model_dir = str(llm_models_root() / model_name) + max_batch_size = 16 + max_output_tokens = 128 + kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.7, + enable_block_reuse=False) + + sparse_attention_config = RocketSparseAttentionConfig( + window_size=32, + kernel_size=63, + prompt_budget=2048, + ) + + llm = LLM( + model=model_dir, + backend=backend, + kv_cache_config=kv_cache_config, + attn_backend=attention_backend, + sparse_attention_config=sparse_attention_config, + max_batch_size=max_batch_size, + max_seq_len=8192, + max_num_tokens=8192, + cuda_graph_config= + None, # sparse attention does not support cuda graph now + ) + + inputs, references = [], [] + current_file = os.path.abspath(__file__) + current_dir = os.path.dirname(os.path.dirname( + os.path.dirname(current_file))) + input_file = f'{current_dir}/multi_gpu/test_star_attention_input.jsonl' + with open(input_file, 'r') as f: + for line in f: + sample = json.loads(line) + inputs.append({ + 'prompt': + sample['input_context'] + sample['input_query'], + }) + references.append(sample['outputs'][0]) + + with llm: + outputs = llm.generate( + inputs, + use_tqdm=True, + sampling_params=SamplingParams(add_special_tokens=False, + max_tokens=max_output_tokens, + temperature=0.8, + top_p=0.95), + ) + + count = 0 + for ref, ret in zip(references, outputs): + print(f"ret: {ret.outputs[0].text}") + print(f"ref: {ref}") + if ref not in ret.outputs[0].text: + print(f'reference {ref} is not in the output {ret.outputs[0].text}') + else: + count = count + 1 + acc = count / len(outputs) + + assert acc >= 0.9, 'accuracy test of rocketkv sparse attention failed' + + +if __name__ == '__main__': + test_model("pytorch", "llama-3.1-model/Llama-3.1-8B-Instruct", "VANILLA") + test_model("pytorch", "llama-3.1-model/Llama-3.1-8B-Instruct", "TRTLLM") diff --git a/tests/unittest/api_stability/references/llm.yaml b/tests/unittest/api_stability/references/llm.yaml index 0c4a583ddc..1e535c99e9 100644 --- a/tests/unittest/api_stability/references/llm.yaml +++ b/tests/unittest/api_stability/references/llm.yaml @@ -183,6 
+183,10 @@ methods: annotation: Optional[Literal["rpc", "ray"]] default: null status: prototype + sparse_attention_config: + annotation: Optional[tensorrt_llm.llmapi.llm_args.SparseAttentionConfig] + default: null + status: prototype return_annotation: None generate: parameters: