[TRTLLM-8536][feat] Add the sparse attention framework and one use case: RocketKV support (#8086)

Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com>
Signed-off-by: yuhangh <58161490+heyuhhh@users.noreply.github.com>
Co-authored-by: yuhangh <58161490+heyuhhh@users.noreply.github.com>
Fanrong Li 2025-10-14 23:23:16 +08:00 committed by GitHub
parent 7291cdc422
commit 0d20a8fd61
43 changed files with 5146 additions and 206 deletions

View File

@ -24,6 +24,7 @@
#include "tensorrt_llm/kernels/gptKernels.h"
#include "tensorrt_llm/kernels/kvCacheUtils.h"
#include "tensorrt_llm/kernels/multiHeadAttentionCommon.h"
#include "tensorrt_llm/kernels/sparseAttentionKernels.h"
#include "tensorrt_llm/kernels/unfusedAttentionKernels.h"
#include "tensorrt_llm/runtime/iBuffer.h"
#include "tensorrt_llm/runtime/utils/debugUtils.h"
@ -287,6 +288,9 @@ bool AttentionOp::convertMMHAParamsToXQAParams(tensorrt_llm::kernels::XQAParams&
xqaParams.output_sf = generationsParams.context_buf_sf;
xqaParams.fp4_out_sf_scale = generationsParams.attention_output_sf_scale;
xqaParams.start_token_idx_sf = generationsParams.start_token_idx_sf;
// Parameters for sparse attention
xqaParams.sparse_params = mRuntimeSparseAttentionParams;
xqaParams.use_sparse_attention = useTllmGenSparseAttention();
// Cross attention parameters.
xqaParams.encoder_input_lengths = generationsParams.encoder_input_lengths;
@ -813,7 +817,7 @@ size_t AttentionOp::getWorkspaceSizeForContext(nvinfer1::DataType type, int32_t
}
size_t AttentionOp::getWorkspaceSizeForGeneration(nvinfer1::DataType type, int32_t max_num_seq,
int32_t max_attention_window_size, int32_t max_num_tokens) const noexcept
int32_t max_attention_window_size, int32_t max_num_tokens, int32_t max_blocks_per_sequence) const noexcept
{
if (max_num_tokens == 0)
{
@ -909,11 +913,15 @@ size_t AttentionOp::getWorkspaceSizeForGeneration(nvinfer1::DataType type, int32
size_t xqa_workspace_size = 0;
if (mEnableXQA)
{
int const XQA_NUM_BUFFERS = 7;
int const XQA_NUM_BUFFERS = 8;
size_t xqa_workspaces[XQA_NUM_BUFFERS];
size_t const cu_seqlens_size = sizeof(int) * (batch_beam + 1);
size_t const cu_kv_seqlens_size = sizeof(int) * (batch_beam + 1);
size_t const rotary_inv_freq_size = sizeof(float) * batch_beam * mRotaryEmbeddingDim / 2;
// One workspace entry covering the two sparse-attention buffers: per-head sequence lengths and per-head kv block offsets.
size_t const sparse_attn_cache_size = useTllmGenSparseAttention()
? sizeof(int) * (batch_beam + batch_beam * 2 * max_blocks_per_sequence) * mNumKVHeads
: 0;
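// Example (hypothetical sizes): batch_beam = 4, max_blocks_per_sequence = 16, mNumKVHeads = 8
// -> (4 + 4 * 2 * 16) * 8 = 1056 ints = 4224 bytes for both sparse buffers combined.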
xqa_workspaces[0] = cu_seqlens_size;
xqa_workspaces[1] = cu_kv_seqlens_size;
xqa_workspaces[2] = rotary_inv_freq_size;
@ -922,7 +930,8 @@ size_t AttentionOp::getWorkspaceSizeForGeneration(nvinfer1::DataType type, int32
// Scales used for trtllm-gen kernels.
xqa_workspaces[4] = sizeof(float) * 2;
xqa_workspaces[5] = sizeof(float);
xqa_workspaces[6] = mXqaDispatcher->getWorkspaceSize(
xqa_workspaces[6] = sparse_attn_cache_size;
xqa_workspaces[7] = mXqaDispatcher->getWorkspaceSize(
std::min<uint32_t>(mSpecDecodingMaxGenerationLength * max_num_seq, max_num_tokens));
xqa_workspace_size
= tc::calculateTotalWorkspaceSize(xqa_workspaces, XQA_NUM_BUFFERS, mXqaDispatcher->getWorkspaceAlignment());
@ -1647,6 +1656,10 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
preprocessingParams.spec_decoding_position_offsets = nullptr;
preprocessingParams.logn_scaling = params.logn_scaling_ptr;
// Sparse KV write
preprocessingParams.sparse_kv_indices = mRuntimeSparseAttentionParams.sparse_kv_indices;
preprocessingParams.sparse_kv_offsets = mRuntimeSparseAttentionParams.sparse_kv_offsets;
// Scalars
preprocessingParams.batch_size = params.batch_size;
preprocessingParams.max_input_seq_len = params.input_seq_length;
@ -1676,6 +1689,8 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
preprocessingParams.rotary_vision_start = mVisionStart;
preprocessingParams.rotary_vision_length = mVisionLength;
preprocessingParams.is_last_chunk
= !mAttentionChunkSize.has_value() || (params.input_seq_length == params.max_past_kv_length);
{
std::string const beforeRopeStr = "ctx attention before RoPE at layer " + std::to_string(mLayerIdx);
@ -1841,6 +1856,12 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
gatherInBuffer, params, cu_q_seqlens, cu_cp_partial_seqlens, stream);
sync_check_cuda_error(stream);
}
if (!mIsMLAEnabled) // Only for non-MLA attention
{
invokeKvCachePostprocessing(preprocessingParams, stream);
sync_check_cuda_error(stream);
}
}
else
{

View File

@ -26,6 +26,7 @@
#include "tensorrt_llm/kernels/gptKernels.h"
#include "tensorrt_llm/kernels/kvCacheUtils.h"
#include "tensorrt_llm/kernels/mlaKernels.h"
#include "tensorrt_llm/kernels/sparseAttentionKernels.h"
#include "tensorrt_llm/kernels/xqaDispatcher.h"
#include <cassert>
#include <set>
@ -55,7 +56,7 @@ public:
int32_t cross_kv_length = 0, int32_t max_num_tokens = 0) const noexcept;
// total_num_seq is the sum of beam_width for multiple requests
[[nodiscard]] size_t getWorkspaceSizeForGeneration(nvinfer1::DataType type, int32_t total_num_seq,
int32_t max_attention_window_size, int32_t max_num_tokens) const noexcept;
int32_t max_attention_window_size, int32_t max_num_tokens, int32_t max_blocks_per_sequence) const noexcept;
template <typename T>
class EnqueueParams
@ -156,14 +157,20 @@ public:
ss << "max_cyclic_attention_window_size: " << this->max_cyclic_attention_window_size << std::endl;
ss << "can_use_one_more_block: " << (this->can_use_one_more_block ? "true" : "false") << std::endl;
ss << "sink_token_length: " << this->sink_token_length << std::endl;
ss << "context_lengths: "
<< *(runtime::ITensor::wrap((void*) this->context_lengths, nvinfer1::DataType::kINT32,
runtime::ITensor::makeShape({batch_size})))
<< std::endl;
ss << "sequence_lengths: "
<< *(runtime::ITensor::wrap((void*) this->sequence_lengths, nvinfer1::DataType::kINT32,
runtime::ITensor::makeShape({batch_size})))
<< std::endl;
if (this->context_lengths && batch_size > 0)
{
ss << "context_lengths: "
<< *(runtime::ITensor::wrap((void*) this->context_lengths, nvinfer1::DataType::kINT32,
runtime::ITensor::makeShape({batch_size})))
<< std::endl;
}
if (this->sequence_lengths && batch_size > 0)
{
ss << "sequence_lengths: "
<< *(runtime::ITensor::wrap((void*) this->sequence_lengths, nvinfer1::DataType::kINT32,
runtime::ITensor::makeShape({batch_size})))
<< std::endl;
}
ss << "kv_scale_orig_quant: " << this->kv_scale_orig_quant << std::endl;
ss << "kv_scale_quant_orig: " << this->kv_scale_quant_orig << std::endl;
ss << "attention_output_orig_quant: " << this->attention_output_orig_quant << std::endl;
@ -348,6 +355,16 @@ public:
return mIsMLAEnabled;
}
[[nodiscard]] bool useSparseAttention() const
{
return mUseSparseAttention && mPagedKVCache && mEnableXQA;
}
[[nodiscard]] bool useTllmGenSparseAttention() const
{
return mUseTllmGenSparseAttention && useSparseAttention();
}
[[nodiscard]] int smVersion() const
{
return mSM;
@ -427,6 +444,8 @@ public:
bool mIsMLAEnabled = false;
bool mIsGenerationMLA = false;
bool mUseGenFlashMLA = false;
bool mUseSparseAttention = false;
bool mUseTllmGenSparseAttention = false;
tensorrt_llm::kernels::MlaMetaParams mMLAParams;
int mCpSize = 1;
int mCpRank = 0;
@ -454,6 +473,8 @@ public:
// Whether to fuse FP4 quant into attention kernel.
bool mFuseFp4Quant = false;
kernels::SparseAttentionParams mRuntimeSparseAttentionParams;
// These are implementation details that we want to save when serializing, but not expose as
// a plugin field or a constructor parameter
int32_t mNbMultiBlockSemaphores = 0;
@ -473,10 +494,11 @@ public:
mPosShiftEnabled, mPagedContextFMHA, mFP8ContextFMHA, mFP8AttenOutput, mFP8ContextMLA, mFP8GenerationMLA,
mChunkPrefillBufferBatchSize, mDenseContextFMHA, mHasFullAttentionMask, mIsSpecDecodingEnabled,
mUseSpecDecoding, mIsSpecDecTree, mSpecDecodingIsGenerationLengthVariable, mSpecDecodingMaxGenerationLength,
mIsMLAEnabled, mIsGenerationMLA, mUseGenFlashMLA, mMLAParams.data(), mCpSize, mCpRank, mCpGroup,
mNumAttnHeads, mNumAttnKVHeads, mNumKVHeadsOrigin, mAttnTpSize, mAttnTpRank, mAttnCpSize, mAttnCpRank,
mUlyssesMQABroadcast, mEnableContextFMHA, mFMHAForceFP32Acc, mMultiBlockMode, mEnableXQA, mUseKVCache,
mSkipAttn, mFuseFp4Quant, mNbMultiBlockSemaphores, mAttentionChunkSize.value_or(-1));
mIsMLAEnabled, mIsGenerationMLA, mUseGenFlashMLA, mUseSparseAttention, mUseTllmGenSparseAttention,
mMLAParams.data(), mCpSize, mCpRank, mCpGroup, mNumAttnHeads, mNumAttnKVHeads, mNumKVHeadsOrigin,
mAttnTpSize, mAttnTpRank, mAttnCpSize, mAttnCpRank, mUlyssesMQABroadcast, mEnableContextFMHA,
mFMHAForceFP32Acc, mMultiBlockMode, mEnableXQA, mUseKVCache, mSkipAttn, mFuseFp4Quant,
mRuntimeSparseAttentionParams.data(), mNbMultiBlockSemaphores, mAttentionChunkSize.value_or(-1));
};
private:

View File

@ -233,6 +233,8 @@ struct XQALaunchParam
float* bmm2_scale_ptr = nullptr;
int32_t* semaphores = nullptr;
void* scratch = nullptr;
void* sparse_kv_block_offsets = nullptr;
int32_t* sparse_seq_lengths = nullptr;
};
// Setup launch params and ioScratch. ioScratch is for RoPE and output type conversion.
@ -266,6 +268,9 @@ void buildXQALaunchParams(XQALaunchParam<KVCacheBuffer>& launchParams, void*& in
const size_t cu_kv_seqlens_size = sizeof(int) * (batch_beam_size + 1);
const size_t rotary_inv_freq_size = sizeof(float) * batch_beam_size * params.rotary_embedding_dim / 2;
const size_t tokens_info_size = sizeof(int2) * params.total_num_input_tokens;
const size_t kv_block_offsets_size
= sizeof(int) * batch_beam_size * 2 * params.max_blocks_per_sequence * params.num_kv_heads;
const size_t seq_lengths_size = sizeof(int) * batch_beam_size * params.num_kv_heads;
launchParams.cu_seq_lens = reinterpret_cast<int*>(workspace);
workspace = tensorrt_llm::common::nextWorkspacePtrWithAlignment(workspace, cu_seqlens_size);
launchParams.cu_kv_seq_lens = reinterpret_cast<int*>(workspace);
@ -281,6 +286,14 @@ void buildXQALaunchParams(XQALaunchParam<KVCacheBuffer>& launchParams, void*& in
workspace = tensorrt_llm::common::nextWorkspacePtrWithAlignment(workspace, bmm1_scale_size);
launchParams.bmm2_scale_ptr = reinterpret_cast<float*>(workspace);
workspace = tensorrt_llm::common::nextWorkspacePtrWithAlignment(workspace, bmm2_scale_size);
// Used for block sparse attention
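// The two buffers carved below are later filled by invokeGatherKvPageOffsets:
//   sparse_kv_block_offsets: [num_kv_heads, batch_beam, 2, max_blocks_per_sequence] int32
//   sparse_seq_lengths:      [num_kv_heads, batch_beam] int32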
if (params.use_sparse_attention)
{
launchParams.sparse_kv_block_offsets = reinterpret_cast<void*>(workspace);
workspace = tensorrt_llm::common::nextWorkspacePtrWithAlignment(workspace, kv_block_offsets_size);
launchParams.sparse_seq_lengths = reinterpret_cast<int*>(workspace);
workspace = tensorrt_llm::common::nextWorkspacePtrWithAlignment(workspace, seq_lengths_size);
}
inputScratch = workspace;
if (hasOutputScratch)
{

View File

@ -17,6 +17,7 @@
#include "tensorrt_llm/common/quantization.h"
#include "tensorrt_llm/kernels/gptKernels.h"
#include "tensorrt_llm/kernels/multiHeadAttentionCommon.h"
#include "tensorrt_llm/kernels/sparseAttentionKernels.h"
namespace tensorrt_llm
{
@ -109,6 +110,10 @@ struct XQAParams
// for cross attention
int32_t const* encoder_input_lengths = nullptr;
// sparse attention parameters
SparseAttentionParams sparse_params;
bool use_sparse_attention = false;
cudaStream_t stream = 0;
std::string toString() const
@ -179,6 +184,8 @@ struct XQAParams
<< "is_fp8_output :" << (is_fp8_output ? "true" : "false") << std ::endl
<< "fp8_out_scale :" << fp8_out_scale << std ::endl
<< "encoder_input_lengths: " << encoder_input_lengths << std::endl
<< "sparse_params: " << sparse_params.toString() << std::endl
<< "use_sparse_attention :" << (use_sparse_attention ? "true" : "false") << std ::endl
<< "stream :" << stream;
return ss.str();

View File

@ -0,0 +1,132 @@
/*
* Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/kernels/sparseAttentionKernels.h"
#include <cub/cub.cuh>
namespace tensorrt_llm
{
namespace kernels
{
template <int THREADS_PER_BLOCK>
__global__ void gatherKvPageOffsetsKernel(
int32_t* output_kv_page_offsets, // [num_head_kv, batch_size, 2, max_num_pages_per_seq]
int32_t* output_seq_lengths, // [num_head_kv, batch_size]
int32_t const* kv_page_offsets, // [batch_size, 2, max_num_pages_per_seq]
int32_t const* seq_lengths, // [batch_size]
SparseAttentionParams const sparse_params, int32_t const batch_size, int32_t const tokens_per_page,
int32_t const max_num_pages_per_seq)
{
// Each CUDA block processes one sequence from the batch for one head.
int32_t const head_idx = blockIdx.x;
int32_t const batch_idx = blockIdx.y;
if (batch_idx >= batch_size)
{
return;
}
// Shared memory for reduction.
__shared__ typename cub::BlockReduce<Pair, THREADS_PER_BLOCK>::TempStorage temp_storage;
// Get the range of sparse indices and the sequence length.
int32_t const start_offset = sparse_params.sparse_attn_offsets[batch_idx];
int32_t const end_offset = sparse_params.sparse_attn_offsets[batch_idx + 1];
int32_t const total_pages = sparse_params.sparse_attn_offsets[batch_size];
int32_t const num_sparse_pages = end_offset - start_offset;
int32_t const original_seq_len = seq_lengths[batch_idx];
// Get global sparse index.
int32_t const sparse_idx_global = head_idx * total_pages + start_offset;
// Get the base memory offset. shape: [batch_size, 2, max_num_pages_per_seq]
size_t const src_base_offset = (size_t) batch_idx * 2 * max_num_pages_per_seq;
size_t const dst_base_offset = (size_t) head_idx * batch_size * 2 * max_num_pages_per_seq + src_base_offset;
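// e.g. with head_idx = 1, batch_idx = 0, batch_size = 2, max_num_pages_per_seq = 8:
// src_base_offset = 0 and dst_base_offset = 1 * 2 * 2 * 8 + 0 = 32.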
// Initialize the local max page index and number of valid pages.
int32_t local_max_page_index = -1;
int32_t local_num_valid_pages = 0;
// Perform the gather operation.
for (int32_t i = threadIdx.x; i < num_sparse_pages; i += blockDim.x)
{
// Get the source idx and offset.
int32_t const src_idx = sparse_params.sparse_attn_indices[sparse_idx_global + i];
if (src_idx < 0)
{
continue;
}
// Update the local max page index.
local_max_page_index = max(local_max_page_index, src_idx);
local_num_valid_pages++;
// Get the source and destination offsets.
size_t const src_offset_dim0 = src_base_offset + 0 * max_num_pages_per_seq + src_idx;
size_t const src_offset_dim1 = src_base_offset + 1 * max_num_pages_per_seq + src_idx;
size_t const dst_offset_dim0 = dst_base_offset + 0 * max_num_pages_per_seq + i;
size_t const dst_offset_dim1 = dst_base_offset + 1 * max_num_pages_per_seq + i;
// Perform the gather operation: read from the sparse location and write to the dense location.
output_kv_page_offsets[dst_offset_dim0] = kv_page_offsets[src_offset_dim0];
output_kv_page_offsets[dst_offset_dim1] = kv_page_offsets[src_offset_dim1];
}
// Reduce the local max page indices and number of valid pages.
Pair local_pair = {local_max_page_index, local_num_valid_pages};
Pair result = cub::BlockReduce<Pair, THREADS_PER_BLOCK>(temp_storage).Reduce(local_pair, PairReduceOp());
// Update sequence length for this head and batch.
if (threadIdx.x == 0)
{
int32_t const max_page_index = result.max_val;
int32_t const num_valid_pages = result.sum_val;
int32_t const ori_valid_pages = (original_seq_len + tokens_per_page - 1) / tokens_per_page;
size_t const seq_len_offset = (size_t) head_idx * batch_size + batch_idx;
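// Worked example: original_seq_len = 146, tokens_per_page = 64 -> ori_valid_pages = 3, remainder 18.
// Keeping pages {0, 1} gives num_valid_pages = 2 and max_page_index = 1; the dropped page is the
// partial last one, so the length is rounded to full pages: 146 - 64 = 82, then 82 + (64 - 18) = 128.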
if (num_valid_pages > 0)
{
int32_t seq_len = original_seq_len - (ori_valid_pages - num_valid_pages) * tokens_per_page;
int32_t seq_len_remain = original_seq_len % tokens_per_page;
if (max_page_index != ori_valid_pages - 1 && seq_len_remain != 0)
{
seq_len += tokens_per_page - seq_len_remain;
}
output_seq_lengths[seq_len_offset] = seq_len;
}
else
{
output_seq_lengths[seq_len_offset] = 0;
}
}
}
// Host-side launcher function
void invokeGatherKvPageOffsets(int32_t* output_kv_page_offsets, int32_t* output_seq_lengths,
int32_t const* kv_page_offsets, int32_t const* seq_lengths, SparseAttentionParams const sparse_params,
int32_t const batch_size, int32_t const num_head_kv, int32_t const tokens_per_page,
int32_t const max_num_pages_per_seq, cudaStream_t stream)
{
// The grid.
dim3 grid(num_head_kv, batch_size, 1);
// The block.
dim3 block(256, 1, 1);
// Shared memory size.
size_t smem_size = sizeof(Pair) * 256;
// Launch the kernel.
gatherKvPageOffsetsKernel<256><<<grid, block, smem_size, stream>>>(output_kv_page_offsets, output_seq_lengths,
kv_page_offsets, seq_lengths, sparse_params, batch_size, tokens_per_page, max_num_pages_per_seq);
}
} // namespace kernels
} // namespace tensorrt_llm

View File

@ -0,0 +1,81 @@
/*
* Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cstdint>
#include <cuda_runtime.h>
#include <sstream>
#include <string>
#include <tuple>
namespace tensorrt_llm
{
namespace kernels
{
struct SparseAttentionParams
{
int32_t* sparse_kv_indices{nullptr}; // [num_kv_heads, num_sparse_kv_indices]
int32_t* sparse_attn_indices{nullptr}; // [num_kv_heads, num_sparse_attn_indices]
int32_t* sparse_kv_offsets{nullptr}; // [num_contexts + 1]
int32_t* sparse_attn_offsets{nullptr}; // [num_generations + 1]
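// Example layout (illustrative): two context requests selecting 3 and 2 KV tokens per head give
// sparse_kv_offsets = {0, 3, 5}; sparse_kv_indices then stores the per-head token lists back to
// back, so head h occupies entries [h * 5, (h + 1) * 5). The attn variants follow the same scheme
// at page granularity for generation requests.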
std::string toString() const
{
std::stringstream ss;
ss << "sparse_kv_indices: " << this->sparse_kv_indices << std::endl
<< "sparse_attn_indices: " << this->sparse_attn_indices << std::endl
<< "sparse_kv_offsets: " << this->sparse_kv_offsets << std::endl
<< "sparse_attn_offsets: " << this->sparse_attn_offsets << std::endl;
return ss.str();
}
auto data() const
{
return std::make_tuple(sparse_kv_indices, sparse_attn_indices, sparse_kv_offsets, sparse_attn_offsets);
}
};
struct Pair
{
int32_t max_val;
int32_t sum_val;
};
struct PairReduceOp
{
#if defined(__CUDACC__)
inline __device__
#endif
Pair
operator()(Pair const& a, Pair const& b) const
{
Pair result;
result.max_val = a.max_val > b.max_val ? a.max_val : b.max_val;
result.sum_val = a.sum_val + b.sum_val;
return result;
}
};
void invokeGatherKvPageOffsets(int32_t* output_kv_page_offsets, // [num_head_kv, batch_size, 2, max_num_pages_per_seq]
int32_t* output_seq_lengths, // [num_head_kv, batch_size]
int32_t const* kv_page_offsets, // [batch_size, 2, max_num_pages_per_seq]
int32_t const* seq_lengths, // [batch_size]
SparseAttentionParams const sparse_params, int32_t const batch_size, int32_t const num_head_kv,
int32_t const tokens_per_page, int32_t const max_num_pages_per_seq, cudaStream_t stream);
} // namespace kernels
} // namespace tensorrt_llm

View File

@ -207,8 +207,6 @@ public:
// Prepare the kernel parameters.
auto kernelParams = KernelParams::setKernelParams(params, kernelMeta, maxNumCtasQ, maxNumCtasKv);
// TODO: set the block sparse attention flag.
kernelParams.mUseBlockSparseAttention = false;
// Prepare kernel parameters list for cuLaunchKernelEx.
void* kernelParamsList[] = {&kernelParams};

View File

@ -148,6 +148,10 @@ struct QKVPreprocessingParams
int const* cu_seq_lens{nullptr};
// list of cumulative KV sequence lengths, of shape {batch_size + 1}, used by cross attention only.
int const* cu_kv_seq_lens{nullptr};
// list of cumulative length of sparse KV indices, of shape {batch_size + 1}
int const* sparse_kv_offsets{nullptr};
// list of sparse KV indices for writing to KV cache, of shape {num_kv_heads, num_sparse_kv_indices}
int const* sparse_kv_indices{nullptr};
// inverse frequencies (angle raised at various powers) from the RoPE formula
// shape of {batch_size , rotaryEmbeddingDim / 2}
float const* rotary_embedding_inv_freq{nullptr};
@ -167,6 +171,7 @@ struct QKVPreprocessingParams
int sink_token_len{0};
int token_num{0};
bool remove_padding{true};
bool is_last_chunk{true};
bool cross_attention{false};
int head_num{0};
int kv_head_num{0};
@ -216,24 +221,48 @@ struct QKVPreprocessingParams
ss << "kv_cache_block_scales_buffer: " << kv_cache_block_scales_buffer.data << std::endl;
ss << "qkv_bias: " << qkv_bias << std::endl;
ss << "tokens_info: " << tokens_info << std::endl;
ss << "seq_lens: "
<< *(runtime::ITensor::wrap(
(void*) seq_lens, nvinfer1::DataType::kINT32, runtime::ITensor::makeShape({batch_size})));
ss << "cache_seq_lens: "
<< *(runtime::ITensor::wrap(
(void*) cache_seq_lens, nvinfer1::DataType::kINT32, runtime::ITensor::makeShape({batch_size})));
ss << "encoder_seq_lens: "
<< *(runtime::ITensor::wrap(
(void*) encoder_seq_lens, nvinfer1::DataType::kINT32, runtime::ITensor::makeShape({batch_size})));
ss << "cu_seq_lens: "
<< *(runtime::ITensor::wrap(
(void*) cu_seq_lens, nvinfer1::DataType::kINT32, runtime::ITensor::makeShape({batch_size})));
ss << "cu_kv_seq_lens: "
<< *(runtime::ITensor::wrap(
(void*) cu_kv_seq_lens, nvinfer1::DataType::kINT32, runtime::ITensor::makeShape({batch_size})));
ss << "rotary_embedding_inv_freq: "
<< *(runtime::ITensor::wrap((void*) rotary_embedding_inv_freq, nvinfer1::DataType::kFLOAT,
runtime::ITensor::makeShape({batch_size, rotary_embedding_dim / 2})));
if (seq_lens && batch_size > 0)
{
ss << "seq_lens: "
<< *(runtime::ITensor::wrap(
(void*) seq_lens, nvinfer1::DataType::kINT32, runtime::ITensor::makeShape({batch_size})));
}
if (cache_seq_lens && batch_size > 0)
{
ss << "cache_seq_lens: "
<< *(runtime::ITensor::wrap(
(void*) cache_seq_lens, nvinfer1::DataType::kINT32, runtime::ITensor::makeShape({batch_size})));
}
if (encoder_seq_lens && batch_size > 0)
{
ss << "encoder_seq_lens: "
<< *(runtime::ITensor::wrap(
(void*) encoder_seq_lens, nvinfer1::DataType::kINT32, runtime::ITensor::makeShape({batch_size})));
}
if (cu_seq_lens && batch_size > 0)
{
ss << "cu_seq_lens: "
<< *(runtime::ITensor::wrap(
(void*) cu_seq_lens, nvinfer1::DataType::kINT32, runtime::ITensor::makeShape({batch_size})));
}
if (cu_kv_seq_lens && batch_size > 0)
{
ss << "cu_kv_seq_lens: "
<< *(runtime::ITensor::wrap(
(void*) cu_kv_seq_lens, nvinfer1::DataType::kINT32, runtime::ITensor::makeShape({batch_size})));
}
if (sparse_kv_offsets)
{
ss << "sparse_kv_offsets: "
<< *(runtime::ITensor::wrap((void*) sparse_kv_offsets, nvinfer1::DataType::kINT32,
runtime::ITensor::makeShape({batch_size + 1})));
}
if (rotary_embedding_inv_freq && batch_size > 0 && rotary_embedding_dim > 0)
{
ss << "rotary_embedding_inv_freq: "
<< *(runtime::ITensor::wrap((void*) rotary_embedding_inv_freq, nvinfer1::DataType::kFLOAT,
runtime::ITensor::makeShape({batch_size, rotary_embedding_dim / 2})));
}
ss << "rotary_coef_cache_buffer: " << rotary_coef_cache_buffer << std::endl;
ss << "qkv_scale_orig_quant: " << qkv_scale_orig_quant << std::endl;
ss << "spec_decoding_position_offsets: " << spec_decoding_position_offsets << std::endl;
@ -244,6 +273,7 @@ struct QKVPreprocessingParams
ss << "sink_token_len: " << sink_token_len << std::endl;
ss << "token_num: " << token_num << std::endl;
ss << "remove_padding: " << remove_padding << std::endl;
ss << "is_last_chunk: " << is_last_chunk << std::endl;
ss << "cross_attention: " << cross_attention << std::endl;
ss << "head_num: " << head_num << std::endl;
ss << "kv_head_num: " << kv_head_num << std::endl;
@ -362,23 +392,36 @@ void invokeQKVPreprocessing(QKVPreprocessingParams<T, KVCacheBuffer> params, cud
template <typename T, typename T_cache, typename KVCacheBuffer>
void invokeUpdateCyclicKvCacheAfterFmha(QKVPreprocessingParams<T, KVCacheBuffer> params, cudaStream_t stream);
template <typename T, typename T_cache, typename KVCacheBuffer>
void invokeUpdateSparseKvCacheAfterFmha(QKVPreprocessingParams<T, KVCacheBuffer> params, cudaStream_t stream);
// Debug function to test basic parameter access
template <typename T, typename KVCacheBuffer>
void invokeDebugSparseKvCacheParams(
QKVPreprocessingParams<T, KVCacheBuffer> params, int* debug_output, cudaStream_t stream);
template <typename T, typename KVCacheBuffer>
void invokeKvCachePostprocessing(QKVPreprocessingParams<T, KVCacheBuffer> params, cudaStream_t stream)
{
params.setCommonParameters();
if (params.cache_type == KvCacheDataType::INT8)
// handle sparse KV cache update if needed
if (params.sparse_kv_indices != nullptr && params.sparse_kv_offsets != nullptr && params.is_last_chunk)
{
invokeUpdateCyclicKvCacheAfterFmha<T, int8_t, KVCacheBuffer>(params, stream);
}
if (params.cache_type == KvCacheDataType::INT8)
{
invokeUpdateSparseKvCacheAfterFmha<T, int8_t, KVCacheBuffer>(params, stream);
}
#ifdef ENABLE_FP8
else if (params.cache_type == KvCacheDataType::FP8)
{
invokeUpdateCyclicKvCacheAfterFmha<T, __nv_fp8_e4m3, KVCacheBuffer>(params, stream);
}
else if (params.cache_type == KvCacheDataType::FP8)
{
invokeUpdateSparseKvCacheAfterFmha<T, __nv_fp8_e4m3, KVCacheBuffer>(params, stream);
}
#endif // ENABLE_FP8
else
{
invokeUpdateCyclicKvCacheAfterFmha<T, T, KVCacheBuffer>(params, stream);
else
{
invokeUpdateSparseKvCacheAfterFmha<T, T, KVCacheBuffer>(params, stream);
}
}
}

View File

@ -1709,6 +1709,130 @@ void invokeUpdateCyclicKvCacheAfterFmha(QKVPreprocessingParams<T, KVCacheBuffer>
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T, typename TCache, int BLOCK_SIZE, int Dh, typename KVCacheBuffer>
__global__ __launch_bounds__(BLOCK_SIZE) void updateSparseKvCacheAfterFmha(
QKVPreprocessingParams<T, KVCacheBuffer> params)
{
// The number of 16B vectors per head size in the kv cache.
constexpr int VECS_PER_HEAD = Dh * sizeof(TCache) / 16;
static_assert(BLOCK_SIZE % VECS_PER_HEAD == 0, "Kernel block should be able to handle entire heads.");
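// e.g. Dh = 128 with a half-precision cache gives 128 * 2 / 16 = 16 16B vectors per head,
// and 8 vectors per head for an 8-bit (INT8/FP8) cache.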
int const batch_idx = blockIdx.z;
int const kv_head_idx = blockIdx.y;
int const total_num_sparse_kv_tokens = params.sparse_kv_offsets[params.batch_size];
int const sparse_start_idx = params.sparse_kv_offsets[batch_idx];
int const sparse_end_idx = params.sparse_kv_offsets[batch_idx + 1];
int const num_sparse_tokens = sparse_end_idx - sparse_start_idx;
int const tokens_per_block = blockDim.y;
int const vecs_per_block = blockDim.x;
extern __shared__ uint4 smem[];
uint4* k_smem = smem;
uint4* v_smem = k_smem + tokens_per_block * VECS_PER_HEAD;
for (int token_block_offset = 0; token_block_offset < num_sparse_tokens; token_block_offset += tokens_per_block)
{
int const sparse_token_offset = token_block_offset + threadIdx.y;
if (sparse_token_offset < num_sparse_tokens)
{
int const global_sparse_idx = sparse_start_idx + sparse_token_offset;
int const sparse_idx_offset = kv_head_idx * total_num_sparse_kv_tokens + global_sparse_idx;
int const src_token_idx = params.sparse_kv_indices[sparse_idx_offset];
void* src_k_ptr = params.kv_cache_buffer.getKBlockPtr(batch_idx, src_token_idx);
void* src_v_ptr = params.kv_cache_buffer.getVBlockPtr(batch_idx, src_token_idx);
auto const src_k_block_ptr = reinterpret_cast<uint4*>(src_k_ptr);
auto const src_v_block_ptr = reinterpret_cast<uint4*>(src_v_ptr);
for (int head_vec_idx = threadIdx.x; head_vec_idx < VECS_PER_HEAD; head_vec_idx += vecs_per_block)
{
auto const src_k_vec_idx
= params.kv_cache_buffer.getKVLocalIdx(src_token_idx, kv_head_idx, VECS_PER_HEAD, head_vec_idx);
auto const src_v_vec_idx
= params.kv_cache_buffer.getKVLocalIdx(src_token_idx, kv_head_idx, VECS_PER_HEAD, head_vec_idx);
k_smem[threadIdx.y * VECS_PER_HEAD + head_vec_idx] = src_k_block_ptr[src_k_vec_idx];
v_smem[threadIdx.y * VECS_PER_HEAD + head_vec_idx] = src_v_block_ptr[src_v_vec_idx];
}
}
__syncthreads();
if (sparse_token_offset < num_sparse_tokens)
{
int const global_sparse_idx = sparse_start_idx + sparse_token_offset;
int const sparse_idx_offset = kv_head_idx * total_num_sparse_kv_tokens + global_sparse_idx;
int const src_token_idx = params.sparse_kv_indices[sparse_idx_offset];
int const dst_token_idx = sparse_token_offset;
if (src_token_idx != dst_token_idx)
{
void* dst_k_ptr = params.kv_cache_buffer.getKBlockPtr(batch_idx, dst_token_idx);
void* dst_v_ptr = params.kv_cache_buffer.getVBlockPtr(batch_idx, dst_token_idx);
auto const dst_k_block_ptr = reinterpret_cast<uint4*>(dst_k_ptr);
auto const dst_v_block_ptr = reinterpret_cast<uint4*>(dst_v_ptr);
for (int head_vec_idx = threadIdx.x; head_vec_idx < VECS_PER_HEAD; head_vec_idx += vecs_per_block)
{
auto const dst_k_vec_idx
= params.kv_cache_buffer.getKVLocalIdx(dst_token_idx, kv_head_idx, VECS_PER_HEAD, head_vec_idx);
auto const dst_v_vec_idx
= params.kv_cache_buffer.getKVLocalIdx(dst_token_idx, kv_head_idx, VECS_PER_HEAD, head_vec_idx);
dst_k_block_ptr[dst_k_vec_idx] = k_smem[threadIdx.y * VECS_PER_HEAD + head_vec_idx];
dst_v_block_ptr[dst_v_vec_idx] = v_smem[threadIdx.y * VECS_PER_HEAD + head_vec_idx];
}
}
}
__syncthreads();
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int Dh, typename T, typename TCache, typename KVCacheBuffer>
void kernelSparseDispatchHeadSize(QKVPreprocessingParams<T, KVCacheBuffer> params, cudaStream_t stream)
{
constexpr int VECS_PER_HEAD = Dh * sizeof(TCache) / 16;
constexpr int BLOCK_SIZE = 1024;
dim3 block(32, 32); // x: head vectors, y: tokens
int smem_size = 2 * block.y * VECS_PER_HEAD * sizeof(uint4);
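// e.g. Dh = 128 with a half-precision cache: 2 * 32 * 16 * sizeof(uint4) = 16 KiB of staging
// shared memory for the K and V tiles.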
// grid.x is always 1 to avoid data races
dim3 grid(1, params.kv_head_num, params.batch_size);
updateSparseKvCacheAfterFmha<T, TCache, BLOCK_SIZE, Dh, KVCacheBuffer><<<grid, block, smem_size, stream>>>(params);
}
template <typename T, typename TCache, typename KVCacheBuffer>
void invokeUpdateSparseKvCacheAfterFmha(QKVPreprocessingParams<T, KVCacheBuffer> params, cudaStream_t stream)
{
if (params.sparse_kv_indices == nullptr)
{
return;
}
switch (params.size_per_head)
{
case 16: kernelSparseDispatchHeadSize<16, T, TCache, KVCacheBuffer>(params, stream); break;
case 32: kernelSparseDispatchHeadSize<32, T, TCache, KVCacheBuffer>(params, stream); break;
case 64: kernelSparseDispatchHeadSize<64, T, TCache, KVCacheBuffer>(params, stream); break;
case 128: kernelSparseDispatchHeadSize<128, T, TCache, KVCacheBuffer>(params, stream); break;
case 256: kernelSparseDispatchHeadSize<256, T, TCache, KVCacheBuffer>(params, stream); break;
default:
TLLM_CHECK_WITH_INFO(
false, "updateSparseKvCacheAfterFmha kernel doesn't support head size = %d", params.size_per_head);
break;
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#define INSTANTIATE_ATTENTION_INPUT_PROCESSING(T, TCache, KVCacheBuffer) \
template void invokeApplyBiasRopeUpdateKVCacheDispatch<T, TCache, KVCacheBuffer>( \
QKVPreprocessingParams<T, KVCacheBuffer> params, cudaStream_t stream);
@ -1717,9 +1841,10 @@ void invokeUpdateCyclicKvCacheAfterFmha(QKVPreprocessingParams<T, KVCacheBuffer>
template void invokeApplyBiasRopeUpdateKVCacheDispatch<T, TCache, KVCacheBuffer>( \
QKVPreprocessingParams<T, KVCacheBuffer> params, cudaStream_t stream); \
template void invokeUpdateCyclicKvCacheAfterFmha<T, TCache, KVCacheBuffer>( \
QKVPreprocessingParams<T, KVCacheBuffer> params, cudaStream_t stream);
////////////////////////////////////////////////////////////////////////////////////////////////////
QKVPreprocessingParams<T, KVCacheBuffer> params, cudaStream_t stream); \
template void invokeUpdateSparseKvCacheAfterFmha<T, TCache, KVCacheBuffer>( \
QKVPreprocessingParams<T, KVCacheBuffer> params, cudaStream_t stream); \
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernels
} // namespace tensorrt_llm

View File

@ -17,6 +17,7 @@
#include "xqaDispatcher.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplCommon.h"
#include "tensorrt_llm/kernels/sparseAttentionKernels.h"
#include "tensorrt_llm/kernels/unfusedAttentionKernels.h"
#include <cstdint>
@ -404,9 +405,15 @@ void XqaDispatcher::runImpl(
// Otherwise, always enable the persistent scheduler for better performance.
tllmRunnerParams.mTileScheduler = params.multi_block_mode ? TileScheduler::Static : TileScheduler::Persistent;
// The sequence lengths for K/V.
tllmRunnerParams.seqLensKvPtr = params.cross_attention ? params.encoder_input_lengths : params.sequence_lengths;
// Q buffer.
tllmRunnerParams.qPtr = xqa_q_input_ptr;
// KV buffer
// Use block sparse attention.
tllmRunnerParams.mUseBlockSparseAttention = false;
if constexpr (std::is_same_v<KVCacheBuffer, KVBlockArray>)
{
// Paged KV
@ -416,9 +423,24 @@ void XqaDispatcher::runImpl(
tllmRunnerParams.kvPageIdxPtr = reinterpret_cast<KVCacheIndex::UnderlyingType const*>(kv_cache_buffer.data);
tllmRunnerParams.mMaxNumPagesPerSeqKv = kv_cache_buffer.mMaxBlocksPerSeq;
tllmRunnerParams.mNumTokensPerPage = kv_cache_buffer.mTokensPerBlock;
// Gather kv page offsets for sparse attention.
if (params.use_sparse_attention)
{
invokeGatherKvPageOffsets(reinterpret_cast<int32_t*>(launchParams.sparse_kv_block_offsets),
launchParams.sparse_seq_lengths, reinterpret_cast<int32_t const*>(kv_cache_buffer.data),
params.sequence_lengths, params.sparse_params, batch_beam_size, num_kv_heads,
kv_cache_buffer.mTokensPerBlock, kv_cache_buffer.mMaxBlocksPerSeq, params.stream);
sync_check_cuda_error(params.stream);
tllmRunnerParams.seqLensKvPtr = launchParams.sparse_seq_lengths;
tllmRunnerParams.kvPageIdxPtr
= reinterpret_cast<KVCacheIndex::UnderlyingType const*>(launchParams.sparse_kv_block_offsets);
tllmRunnerParams.mUseBlockSparseAttention = true;
}
}
else
{
TLLM_CHECK_WITH_INFO(!params.use_sparse_attention, "Sparse attention is not supported for KVLinearBuffer.");
static_assert(std::is_same_v<KVCacheBuffer, KVLinearBuffer>);
// Contiguous KV
tllmRunnerParams.mQkvLayout = QkvLayout::ContiguousKv;
@ -437,8 +459,6 @@ void XqaDispatcher::runImpl(
tllmRunnerParams.scaleSoftmaxLog2Ptr
= reinterpret_cast<float const*>(launchParams.bmm1_scale_ptr + kIdxScaleSoftmaxLog2Ptr);
tllmRunnerParams.oSfScalePtr = params.fp4_out_sf_scale;
// The sequence lengths for K/V.
tllmRunnerParams.seqLensKvPtr = params.cross_attention ? params.encoder_input_lengths : params.sequence_lengths;
tllmRunnerParams.oPtr = params.output;
tllmRunnerParams.oSfPtr = params.output_sf;

View File

@ -55,7 +55,7 @@ void initBindings(nb::module_& m)
nb::arg("mrope_rotary_cos_sin") = std::nullopt, nb::arg("mrope_position_deltas") = std::nullopt,
nb::arg("mla_tensor_params"), nb::arg("attention_chunk_size") = std::nullopt,
nb::arg("softmax_stats_tensor") = std::nullopt, nb::arg("spec_decoding_bool_params"),
nb::arg("spec_decoding_tensor_params"), "Multi-head attention operation",
nb::arg("spec_decoding_tensor_params"), nb::arg("sparse_attention_params"), "Multi-head attention operation",
nb::call_guard<nb::gil_scoped_release>());
}
} // namespace tensorrt_llm::nanobind::thop

View File

@ -572,11 +572,14 @@ size_t GPTAttentionPlugin::getWorkspaceSize(nvinfer1::PluginTensorDesc const* in
= isCrossAttention() ? cross_kv_length : (useKVCache() ? inputs[getIdx(IdxEntry::CACHE_INDIR)].dims.d[2] : 0);
int const max_num_tokens
= mRemovePadding ? inputs[getIdx(IdxEntry::QKV_TENSOR)].dims.d[0] : max_num_seq * max_context_length;
int const max_blocks_per_sequence
= (useKVCache() && mPagedKVCache) ? inputs[getIdx(IdxEntry::KV_CACHE_BLOCK_OFFSETS)].dims.d[3] : 0;
size_t const context_workspace_size
= getWorkspaceSizeForContext(type, max_num_seq, max_context_length, cross_kv_length, max_num_tokens);
size_t const generation_workspace_size
= getWorkspaceSizeForGeneration(type, max_num_seq, max_kv_cache_length, max_num_tokens);
size_t const generation_workspace_size = getWorkspaceSizeForGeneration(
type, max_num_seq, max_kv_cache_length, max_num_tokens, max_blocks_per_sequence);
size_t attention_input_workspace_size = 0;

View File

@ -55,7 +55,7 @@ void initBindings(pybind11::module_& m)
py::arg("mrope_rotary_cos_sin") = std::nullopt, py::arg("mrope_position_deltas") = std::nullopt,
py::arg("mla_tensor_params"), py::arg("attention_chunk_size") = std::nullopt,
py::arg("softmax_stats_tensor") = std::nullopt, py::arg("spec_decoding_bool_params"),
py::arg("spec_decoding_tensor_params"), "Multi-head attention operation",
py::arg("spec_decoding_tensor_params"), py::arg("sparse_attention_params"), "Multi-head attention operation",
py::call_guard<py::gil_scoped_release>());
}
} // namespace tensorrt_llm::pybind::thop

View File

@ -19,6 +19,7 @@
#include "tensorrt_llm/common/dataType.h"
#include "tensorrt_llm/kernels/gptKernels.h"
#include "tensorrt_llm/kernels/mlaKernels.h"
#include "tensorrt_llm/kernels/sparseAttentionKernels.h"
#include "tensorrt_llm/runtime/torchUtils.h"
#include "tensorrt_llm/runtime/utils/debugUtils.h"
#include "tensorrt_llm/thop/attentionOp.h"
@ -38,6 +39,7 @@ namespace trtllm::attention
{
using tensorrt_llm::kernels::KVBlockArray;
using tensorrt_llm::kernels::MlaParams;
using tensorrt_llm::kernels::SparseAttentionParams;
enum class AttentionInputType : int8_t
{
@ -62,7 +64,7 @@ public:
virtual ~RunnerBase() = default;
virtual void prepare(AttentionOp& op) const = 0;
virtual int64_t getWorkspaceSize(AttentionOp const& op, int const num_tokens, int const max_attention_window_size,
int const num_gen_tokens) const
int const num_gen_tokens, int const max_blocks_per_sequence) const
= 0;
// typically, we use single qkv input, but for context MLA, we use separate qkv inputs
virtual void run(AttentionOp& op, bool const is_context, int32_t const seq_offset, int32_t const num_seqs,
@ -82,7 +84,9 @@ public:
std::vector<std::optional<torch::Tensor>> mla_tensor_params,
torch::optional<torch::Tensor> softmax_stats_tensor,
c10::ArrayRef<std::optional<torch::Tensor>> spec_decoding_tensor_params,
torch::optional<torch::Tensor> attention_sinks) const
torch::optional<torch::Tensor> attention_sinks, torch::optional<torch::Tensor> sparse_kv_indices,
torch::optional<torch::Tensor> sparse_kv_offsets, torch::optional<torch::Tensor> sparse_attn_indices,
torch::optional<torch::Tensor> sparse_attn_offsets) const
= 0;
};
@ -110,12 +114,12 @@ public:
}
int64_t getWorkspaceSize(AttentionOp const& op, int const num_tokens, int const max_attention_window_size,
int const num_gen_tokens) const override
int const num_gen_tokens, int const max_blocks_per_sequence) const override
{
size_t const context_workspace_size
= op.getWorkspaceSizeForContext(op.mType, max_num_requests, op.mMaxContextLength, 0, num_tokens);
size_t const generation_workspace_size
= op.getWorkspaceSizeForGeneration(op.mType, max_num_requests, max_attention_window_size, num_gen_tokens);
size_t const generation_workspace_size = op.getWorkspaceSizeForGeneration(
op.mType, max_num_requests, max_attention_window_size, num_gen_tokens, max_blocks_per_sequence);
return std::max(context_workspace_size, generation_workspace_size);
}
@ -137,7 +141,9 @@ public:
std::vector<std::optional<torch::Tensor>> mla_tensor_params,
torch::optional<torch::Tensor> softmax_stats_tensor,
c10::ArrayRef<std::optional<torch::Tensor>> spec_decoding_tensor_params,
torch::optional<torch::Tensor> attention_sinks) const override
torch::optional<torch::Tensor> attention_sinks, torch::optional<torch::Tensor> sparse_kv_indices,
torch::optional<torch::Tensor> sparse_kv_offsets, torch::optional<torch::Tensor> sparse_attn_indices,
torch::optional<torch::Tensor> sparse_attn_offsets) const override
{
auto stream = at::cuda::getCurrentCUDAStream(qkv_or_q.get_device());
T* attention_input = static_cast<T*>(qkv_or_q.slice(0, token_offset).data_ptr());
@ -221,6 +227,22 @@ public:
}
}
// Prepare sparse attention parameters
if (is_context)
{
op.mRuntimeSparseAttentionParams.sparse_kv_indices
= sparse_kv_indices.has_value() ? sparse_kv_indices.value().data_ptr<int32_t>() : nullptr;
op.mRuntimeSparseAttentionParams.sparse_kv_offsets
= sparse_kv_offsets.has_value() ? sparse_kv_offsets.value().data_ptr<int32_t>() : nullptr;
}
else
{
op.mRuntimeSparseAttentionParams.sparse_attn_indices
= sparse_attn_indices.has_value() ? sparse_attn_indices.value().data_ptr<int32_t>() : nullptr;
op.mRuntimeSparseAttentionParams.sparse_attn_offsets
= sparse_attn_offsets.has_value() ? sparse_attn_offsets.value().data_ptr<int32_t>() : nullptr;
}
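// The kv_* tensors drive the sparse KV-cache compaction after the context-phase FMHA
// (invokeKvCachePostprocessing), while the attn_* tensors drive the per-head page gathering
// (invokeGatherKvPageOffsets) used by the trtllm-gen generation kernels.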
int const* context_lengths_ptr = context_lengths.slice(0, seq_offset).data_ptr<int>();
int const* sequence_lengths_ptr = sequence_length.slice(0, seq_offset).data_ptr<int>();
// Note we still need context length during generation for MMHA optimization.
@ -518,8 +540,16 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
std::optional<torch::Tensor> mrope_rotary_cos_sin, std::optional<torch::Tensor> mrope_position_deltas,
std::vector<std::optional<torch::Tensor>> mla_tensor_params, std::optional<int64_t> attention_chunk_size,
std::optional<torch::Tensor> softmax_stats_tensor, std::vector<bool> spec_decoding_bool_params,
std::vector<std::optional<torch::Tensor>> spec_decoding_tensor_params)
std::vector<std::optional<torch::Tensor>> spec_decoding_tensor_params,
std::vector<std::optional<torch::Tensor>> sparse_attention_params)
{
// Unpack the sparse attention parameters
TORCH_CHECK(sparse_attention_params.size() == 4, "Expected 4 sparse attention parameters");
torch::optional<torch::Tensor> sparse_kv_indices = sparse_attention_params[0];
torch::optional<torch::Tensor> sparse_kv_offsets = sparse_attention_params[1];
torch::optional<torch::Tensor> sparse_attn_indices = sparse_attention_params[2];
torch::optional<torch::Tensor> sparse_attn_offsets = sparse_attention_params[3];
TLLM_LOG_TRACE("Attention op starts at layer %d", layer_idx);
// Use these tensors to infer if the attention is using KV cache
bool const use_kv_cache = kv_cache_block_offsets.has_value() && host_kv_cache_block_offsets.has_value()
@ -633,6 +663,18 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
op->mUseSpecDecoding = spec_decoding_bool_params[1]; // use_spec_decoding
op->mIsSpecDecTree = spec_decoding_bool_params[2]; // is_spec_dec_tree
op->mUseSparseAttention = false;
op->mUseTllmGenSparseAttention = false;
if ((sparse_kv_indices.has_value() && sparse_kv_indices.value().numel() > 0)
|| (sparse_attn_indices.has_value() && sparse_attn_indices.value().numel() > 0))
{
op->mUseSparseAttention = true;
if (sparse_attn_indices.has_value() && sparse_attn_indices.value().numel() > 0)
{
op->mUseTllmGenSparseAttention = true;
}
}
if (is_mla_enable)
{
// MLA does not support NVFP4 output yet.
@ -718,7 +760,10 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
int32_t const max_attention_window_size
= beam_width == 1 ? attention_window_size : cache_indirection.value().size(2);
int64_t const workspace_size = runner->getWorkspaceSize(*op, num_tokens, max_attention_window_size, num_gen_tokens);
int32_t const max_blocks_per_sequence
= use_kv_cache && kv_cache_block_offsets.has_value() ? kv_cache_block_offsets.value().size(-1) : 0;
int64_t const workspace_size
= runner->getWorkspaceSize(*op, num_tokens, max_attention_window_size, num_gen_tokens, max_blocks_per_sequence);
TLLM_LOG_TRACE("Expected workspace size is %ld bytes", workspace_size);
if (workspace_size >= (16l << 30))
@ -760,7 +805,7 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
host_kv_cache_pool_mapping, cache_indirection, kv_scale_orig_quant, kv_scale_quant_orig, out_scale,
rotary_inv_freq, rotary_cos_sin, latent_cache, q_pe, block_ids_per_seq, mrope_rotary_cos_sin,
mrope_position_deltas, mla_tensor_params, softmax_stats_tensor, spec_decoding_tensor_params,
attention_sinks);
attention_sinks, sparse_kv_indices, sparse_kv_offsets, sparse_attn_indices, sparse_attn_offsets);
}
if ((num_generations > 0) && (attn_input_type != AttentionInputType::ContextOnly))
@ -777,7 +822,7 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
host_kv_cache_pool_mapping, cache_indirection, kv_scale_orig_quant, kv_scale_quant_orig, out_scale,
rotary_inv_freq, rotary_cos_sin, latent_cache, q_pe, block_ids_per_seq, mrope_rotary_cos_sin,
mrope_position_deltas, mla_tensor_params, softmax_stats_tensor, spec_decoding_tensor_params,
attention_sinks);
attention_sinks, sparse_kv_indices, sparse_kv_offsets, sparse_attn_indices, sparse_attn_offsets);
}
TLLM_LOG_TRACE("Attention op stops at layer %d", layer_idx);

View File

@ -57,9 +57,10 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
std::optional<int64_t> chunked_prefill_buffer_batch_size, std::optional<int64_t> q_lora_rank,
std::optional<int64_t> kv_lora_rank, std::optional<int64_t> qk_nope_head_dim,
std::optional<int64_t> qk_rope_head_dim, std::optional<int64_t> v_head_dim,
torch::optional<torch::Tensor> mrope_rotary_cos_sin, torch::optional<torch::Tensor> mrope_position_deltas,
std::optional<torch::Tensor> mrope_rotary_cos_sin, std::optional<torch::Tensor> mrope_position_deltas,
std::vector<std::optional<torch::Tensor>> mla_tensor_params, std::optional<int64_t> attention_chunk_size,
std::optional<torch::Tensor> softmax_stats_tensor, std::vector<bool> spec_decoding_bool_params,
std::vector<std::optional<torch::Tensor>> spec_decoding_tensor_params);
std::vector<std::optional<torch::Tensor>> spec_decoding_tensor_params,
std::vector<std::optional<torch::Tensor>> sparse_attention_params);
} // namespace torch_ext

View File

@ -40,6 +40,7 @@ add_gtest(smoothQuantKernelTest smoothQuant/smoothQuantKernelTest.cpp)
add_gtest(stopCriteriaKernelsTest stopCriteriaKernelsTest.cpp)
add_gtest(weightOnlyKernelTest weightOnly/weightOnlyKernelTest.cpp)
add_gtest(mlaPreprocessTest mlaPreprocessTest.cu)
add_gtest(sparseAttentionKernelsTest sparseAttentionKernelsTest.cpp)
add_gtest(cudaCoreGemmKernelTest cudaCoreGemm/cudaCoreGemmKernelTest.cpp)
@ -90,3 +91,4 @@ add_gtest(routingKernelsTest "${ROUTING_KERNEL_TEST_SRC}")
add_gtest(moeLoadBalanceKernelTest moeLoadBalanceKernelTest.cpp)
add_gtest(eaglePackDataTest eaglePackDataTest.cpp)
add_gtest(sparseKvCacheTest sparseKvCacheTest.cu)

View File

@ -0,0 +1,191 @@
#include <gtest/gtest.h>
#include "tensorrt_llm/kernels/sparseAttentionKernels.h"
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/cudaStream.h"
#include <memory>
#include <vector>
using namespace tensorrt_llm::kernels;
using namespace tensorrt_llm::runtime;
namespace
{
class sparseAttentionKernelsTest : public ::testing::Test
{
public:
void SetUp() override
{
mStream = std::make_shared<CudaStream>();
mBufferManager = std::make_shared<BufferManager>(mStream);
}
void TearDown() override {}
protected:
std::shared_ptr<CudaStream> mStream;
std::shared_ptr<BufferManager> mBufferManager;
};
TEST_F(sparseAttentionKernelsTest, GatherKvPageOffsetsKernelTest)
{
// Test parameters
constexpr int max_batch_size = 4;
constexpr int batch_size = 2;
constexpr int num_head_kv = 4;
constexpr int max_num_pages_per_seq = 8;
constexpr int tokens_per_page = 64;
constexpr int total_sparse_pages = max_batch_size * max_num_pages_per_seq; // Total sparse pages across all batches
// Create input buffers
auto kv_page_offsets
= mBufferManager->gpu(ITensor::makeShape({batch_size, 2, max_num_pages_per_seq}), nvinfer1::DataType::kINT32);
auto seq_lengths = mBufferManager->gpu(ITensor::makeShape({batch_size}), nvinfer1::DataType::kINT32);
auto sparse_indices
= mBufferManager->gpu(ITensor::makeShape({total_sparse_pages, num_head_kv}), nvinfer1::DataType::kINT32);
auto sparse_indices_offsets = mBufferManager->gpu(ITensor::makeShape({batch_size + 1}), nvinfer1::DataType::kINT32);
// Create output buffers
auto output_kv_page_offsets = mBufferManager->gpu(
ITensor::makeShape({num_head_kv, batch_size, 2, max_num_pages_per_seq}), nvinfer1::DataType::kINT32);
auto output_seq_lengths
= mBufferManager->gpu(ITensor::makeShape({num_head_kv, batch_size}), nvinfer1::DataType::kINT32);
// Create pinned host buffers for data initialization
auto kv_page_offsets_host = mBufferManager->pinned(
ITensor::makeShape({batch_size, 2, max_num_pages_per_seq}), nvinfer1::DataType::kINT32);
auto seq_lengths_host = mBufferManager->pinned(ITensor::makeShape({batch_size}), nvinfer1::DataType::kINT32);
auto sparse_indices_host
= mBufferManager->pinned(ITensor::makeShape({total_sparse_pages, num_head_kv}), nvinfer1::DataType::kINT32);
auto sparse_indices_offsets_host
= mBufferManager->pinned(ITensor::makeShape({batch_size + 1}), nvinfer1::DataType::kINT32);
// Initialize test data
auto kv_page_offsets_ptr = bufferCast<int32_t>(*kv_page_offsets_host);
auto seq_lengths_ptr = bufferCast<int>(*seq_lengths_host);
auto sparse_indices_ptr = bufferCast<int>(*sparse_indices_host);
auto sparse_indices_offsets_ptr = bufferCast<int>(*sparse_indices_offsets_host);
// Initialize KV page offsets with test data
for (int b = 0; b < batch_size; ++b)
{
for (int d = 0; d < 2; ++d)
{
for (int p = 0; p < max_num_pages_per_seq; ++p)
{
int offset = b * 2 * max_num_pages_per_seq + d * max_num_pages_per_seq + p;
kv_page_offsets_ptr[offset] = 1000 + b * 100 + d * 10 + p;
}
}
}
// Initialize sequence lengths
seq_lengths_ptr[0] = 2 * tokens_per_page + 18; // 3 pages for batch 0
seq_lengths_ptr[1] = 3 * tokens_per_page + 3; // 4 pages for batch 1
// Initialize sparse indices with different patterns for different heads.
// Logical layout is [num_head_kv, num_sparse_pages] (head-major), matching the kernel's indexing;
// each head can have its own sparse pattern.
int num_sparse_pages = 5;
int sparse_page_indices[5][4] = {{1, 0, 0, 1}, {2, 1, 1, -1}, {-1, 2, -1, -1}, {0, 1, 2, 3}, {3, 3, 3, -1}};
for (int page = 0; page < num_sparse_pages; ++page)
{
for (int head = 0; head < num_head_kv; ++head)
{
int idx = head * num_sparse_pages + page;
sparse_indices_ptr[idx] = sparse_page_indices[page][head];
}
}
// Initialize sparse indices offsets
sparse_indices_offsets_ptr[0] = 0; // Start of batch 0
sparse_indices_offsets_ptr[1] = 3; // Start of batch 1 (3 sparse pages for batch 0)
sparse_indices_offsets_ptr[2] = 5; // End (2 sparse pages for batch 1)
// Copy data to GPU
mBufferManager->copy(*kv_page_offsets_host, *kv_page_offsets);
mBufferManager->copy(*seq_lengths_host, *seq_lengths);
mBufferManager->copy(*sparse_indices_host, *sparse_indices);
mBufferManager->copy(*sparse_indices_offsets_host, *sparse_indices_offsets);
SparseAttentionParams sparse_params;
sparse_params.sparse_attn_indices = bufferCast<int32_t>(*sparse_indices);
sparse_params.sparse_attn_offsets = bufferCast<int32_t>(*sparse_indices_offsets);
// Launch the kernel
invokeGatherKvPageOffsets(bufferCast<int32_t>(*output_kv_page_offsets), bufferCast<int32_t>(*output_seq_lengths),
bufferCast<int32_t>(*kv_page_offsets), bufferCast<int32_t>(*seq_lengths), sparse_params, batch_size,
num_head_kv, tokens_per_page, max_num_pages_per_seq, mStream->get());
// Wait for completion
mStream->synchronize();
// Copy results back to host for verification
auto output_kv_page_offsets_host = mBufferManager->pinned(
ITensor::makeShape({num_head_kv, batch_size, 2, max_num_pages_per_seq}), nvinfer1::DataType::kINT32);
auto output_seq_lengths_host
= mBufferManager->pinned(ITensor::makeShape({num_head_kv, batch_size}), nvinfer1::DataType::kINT32);
mBufferManager->copy(*output_kv_page_offsets, *output_kv_page_offsets_host);
mBufferManager->copy(*output_seq_lengths, *output_seq_lengths_host);
// Wait for completion
mStream->synchronize();
auto output_kv_offsets_ptr = bufferCast<int32_t>(*output_kv_page_offsets_host);
auto output_seq_len_ptr = bufferCast<int>(*output_seq_lengths_host);
// Verify sequence lengths for each head and batch
int expected_seq_lens[4][2] = {
{tokens_per_page + 18, tokens_per_page + 3}, // Head 0: batch 0 keeps 2 pages (incl. the partial last page), batch 1 keeps 2 pages
{2 * tokens_per_page + 18, tokens_per_page + 3}, // Head 1: batch 0 keeps all 3 pages, batch 1 keeps 2 pages
{2 * tokens_per_page, tokens_per_page + 3}, // Head 2: batch 0 keeps 2 full pages (partial last page dropped), batch 1 keeps 2 pages
{tokens_per_page, 3} // Head 3: batch 0 keeps 1 full page, batch 1 keeps only the partial last page
};
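// Derivation for head 2, batch 0 (sketch): kept pages are {0, 1} (the -1 entry is dropped), so
// num_valid_pages = 2 and max_page_index = 1; since the partially filled last page (index 2) is
// dropped, the length rounds to full pages: 146 - 64 = 82, then 82 + (64 - 18) = 128.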
for (int h = 0; h < num_head_kv; ++h)
{
for (int b = 0; b < batch_size; ++b)
{
int seq_len_idx = h * batch_size + b;
EXPECT_EQ(output_seq_len_ptr[seq_len_idx], expected_seq_lens[h][b])
<< "Sequence length mismatch at head=" << h << ", batch=" << b;
}
}
// Verify gathered KV page offsets
for (int h = 0; h < num_head_kv; ++h)
{
for (int b = 0; b < batch_size; ++b)
{
int num_sparse_pages_batch = sparse_indices_offsets_ptr[b + 1] - sparse_indices_offsets_ptr[b];
for (int d = 0; d < 2; ++d)
{
for (int p = 0; p < num_sparse_pages_batch; ++p)
{
// Calculate expected value (from the sparse index)
int sparse_idx_global = sparse_indices_offsets_ptr[b] + p;
int src_page_idx
= sparse_indices_ptr[h * sparse_indices_offsets_ptr[batch_size] + sparse_idx_global];
if (src_page_idx == -1)
{
continue; // Skip invalid indices
}
// Calculate output offset
size_t output_offset = h * batch_size * 2 * max_num_pages_per_seq + b * 2 * max_num_pages_per_seq
+ d * max_num_pages_per_seq + p;
int expected_value = 1000 + b * 100 + d * 10 + src_page_idx;
EXPECT_EQ(output_kv_offsets_ptr[output_offset], expected_value)
<< "Mismatch at head=" << h << ", batch=" << b << ", dim=" << d << ", page=" << p;
}
}
}
}
}
} // namespace

View File

@ -0,0 +1,521 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <memory>
#include <vector>
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/kernels/kvCacheUtils.h"
#include "tensorrt_llm/kernels/unfusedAttentionKernels.h"
using namespace tensorrt_llm::kernels;
class SparseKvCacheTest : public ::testing::Test
{
protected:
void SetUp() override
{
mStream = nullptr;
TLLM_CUDA_CHECK(cudaStreamCreate(&mStream));
// Test parameters
mBatchSize = 2;
mNumKvHeads = 4;
mHeadSize = 128;
mMaxSeqLen = 512;
mTokensPerBlock = 64;
mMaxBlocksPerSeq = 8;
// Allocate test data
setupTestData();
}
void TearDown() override
{
cleanup();
if (mStream)
{
TLLM_CUDA_CHECK(cudaStreamDestroy(mStream));
}
}
void setupTestData()
{
// Allocate device memory for sparse KV indices and offsets
size_t sparse_indices_size = 20 * mNumKvHeads * sizeof(int); // 20 sparse tokens max
size_t sparse_offsets_size = (mBatchSize + 1) * sizeof(int);
TLLM_CUDA_CHECK(cudaMalloc(&mSparseKvIndicesDevice, sparse_indices_size));
TLLM_CUDA_CHECK(cudaMalloc(&mSparseKvOffsetsDevice, sparse_offsets_size));
TLLM_CUDA_CHECK(cudaMalloc(&mSeqLensDevice, mBatchSize * sizeof(int)));
TLLM_CUDA_CHECK(cudaMalloc(&mCacheSeqLensDevice, mBatchSize * sizeof(int)));
// Create sparse indices in the layout the kernel expects: [head_idx][sparse_token_idx]
// Total sparse tokens: 5 (batch 0) + 3 (batch 1) = 8
std::vector<int> sparseKvIndicesHost;
// Batch 0: 5 sparse tokens per head
std::vector<std::vector<int>> batch0_indices = {
{1, 2, 3, 4, 5}, // head 0
{3, 4, 5, 6, 8}, // head 1
{0, 1, 3, 5, 8}, // head 2
{1, 3, 5, 10, 11} // head 3
};
// Batch 1: 3 sparse tokens per head
std::vector<std::vector<int>> batch1_indices = {
{1, 4, 7}, // head 0
{0, 2, 3}, // head 1
{1, 2, 7}, // head 2
{1, 3, 7} // head 3
};
// [num_kv_heads, num_sparse_kv_tokens]
for (int head = 0; head < mNumKvHeads; ++head)
{
for (size_t token = 0; token < batch0_indices[head].size(); ++token)
{
sparseKvIndicesHost.push_back(batch0_indices[head][token]);
}
for (size_t token = 0; token < batch1_indices[head].size(); ++token)
{
sparseKvIndicesHost.push_back(batch1_indices[head][token]);
}
}
std::vector<int> sparseKvOffsetsHost = {0, 5, 8}; // Batch 0: 5 tokens, Batch 1: 3 tokens
std::vector<int> seqLensHost = {12, 8}; // Original sequence lengths
std::vector<int> cacheSeqLensHost = {12, 8}; // Cache sequence lengths
TLLM_CUDA_CHECK(cudaMemcpy(mSparseKvIndicesDevice, sparseKvIndicesHost.data(),
sparseKvIndicesHost.size() * sizeof(int), cudaMemcpyHostToDevice));
TLLM_CUDA_CHECK(cudaMemcpy(mSparseKvOffsetsDevice, sparseKvOffsetsHost.data(),
sparseKvOffsetsHost.size() * sizeof(int), cudaMemcpyHostToDevice));
TLLM_CUDA_CHECK(
cudaMemcpy(mSeqLensDevice, seqLensHost.data(), seqLensHost.size() * sizeof(int), cudaMemcpyHostToDevice));
TLLM_CUDA_CHECK(cudaMemcpy(mCacheSeqLensDevice, cacheSeqLensHost.data(), cacheSeqLensHost.size() * sizeof(int),
cudaMemcpyHostToDevice));
TLLM_CUDA_CHECK(cudaDeviceSynchronize());
// Setup KV cache buffer using KVBlockArray
setupKvCacheBuffer();
}
void setupKvCacheBuffer()
{
// Calculate memory requirements
auto const elemSize = sizeof(half);
auto const sizePerToken = mNumKvHeads * mHeadSize * elemSize;
auto const bytesPerBlock = mTokensPerBlock * sizePerToken;
auto const totalBlocks = mBatchSize * mMaxBlocksPerSeq;
auto const poolSize = totalBlocks * bytesPerBlock * 2; // K and V
// Allocate primary pool
TLLM_CUDA_CHECK(cudaMalloc(&mKvCachePool, poolSize));
TLLM_CUDA_CHECK(cudaMemset(mKvCachePool, 0, poolSize));
// Allocate block offsets
size_t offsetsSize = mBatchSize * mMaxBlocksPerSeq * 2 * sizeof(KVCacheIndex);
TLLM_CUDA_CHECK(cudaMalloc(&mBlockOffsetsDevice, offsetsSize));
// Initialize block offsets (simple linear mapping for test)
std::vector<KVCacheIndex> blockOffsetsHost;
blockOffsetsHost.reserve(mBatchSize * mMaxBlocksPerSeq * 2);
for (int batch = 0; batch < mBatchSize; ++batch)
{
for (int block = 0; block < mMaxBlocksPerSeq; ++block)
{
// K cache block offset
int kBlockIdx = batch * mMaxBlocksPerSeq * 2 + block;
blockOffsetsHost.emplace_back(kBlockIdx, false);
}
for (int block = 0; block < mMaxBlocksPerSeq; ++block)
{
// V cache block offset
int vBlockIdx = batch * mMaxBlocksPerSeq * 2 + mMaxBlocksPerSeq + block;
blockOffsetsHost.emplace_back(vBlockIdx, false);
}
}
TLLM_CUDA_CHECK(cudaMemcpy(mBlockOffsetsDevice, blockOffsetsHost.data(),
blockOffsetsHost.size() * sizeof(KVCacheIndex), cudaMemcpyHostToDevice));
TLLM_CUDA_CHECK(cudaDeviceSynchronize());
// Create KVBlockArray with correct parameter order:
// (batchSize, maxBlocksPerSeq, tokensPerBlock, bytesPerToken,
// maxAttentionWindow, maxAttentionWindowAllLayer, sinkTokenLen, canUseOneMoreBlock,
// primaryPoolPtr, secondaryPoolPtr, data)
mKvCacheBuffer = KVBlockArray(mBatchSize, mMaxBlocksPerSeq, mTokensPerBlock, sizePerToken, mMaxSeqLen,
mMaxSeqLen, 0, false, mKvCachePool, nullptr, mBlockOffsetsDevice);
}
void cleanup()
{
if (mSparseKvIndicesDevice)
cudaFree(mSparseKvIndicesDevice);
if (mSparseKvOffsetsDevice)
cudaFree(mSparseKvOffsetsDevice);
if (mSeqLensDevice)
cudaFree(mSeqLensDevice);
if (mCacheSeqLensDevice)
cudaFree(mCacheSeqLensDevice);
if (mKvCachePool)
cudaFree(mKvCachePool);
if (mBlockOffsetsDevice)
cudaFree(mBlockOffsetsDevice);
if (mQkvInputDevice)
cudaFree(mQkvInputDevice);
}
// Test parameters
int mBatchSize;
int mNumKvHeads;
int mHeadSize;
int mMaxSeqLen;
int mTokensPerBlock;
int mMaxBlocksPerSeq;
// Device memory
int* mSparseKvIndicesDevice = nullptr;
int* mSparseKvOffsetsDevice = nullptr;
int* mSeqLensDevice = nullptr;
int* mCacheSeqLensDevice = nullptr;
void* mKvCachePool = nullptr;
KVCacheIndex* mBlockOffsetsDevice = nullptr;
half* mQkvInputDevice = nullptr;
KVBlockArray mKvCacheBuffer;
cudaStream_t mStream;
// Helper functions for verification
bool verifySparseKvCacheMapping(std::vector<half> const& originalKvCache);
void performHostSparseMapping(std::vector<half> const& originalKvCache, std::vector<half>& expectedKvCache);
void extractKvCacheFromGpu(std::vector<half>& kvCacheHost);
void initializeKvCacheWithPattern();
};
TEST_F(SparseKvCacheTest, UpdateSparseKvCacheAfterFmha)
{
// Allocate dummy QKV input (normally this would come from the attention computation)
size_t qkvInputSize = mBatchSize * mMaxSeqLen * 3 * mNumKvHeads * mHeadSize * sizeof(half);
TLLM_CUDA_CHECK(cudaMalloc(&mQkvInputDevice, qkvInputSize));
// Initialize with test pattern
std::vector<half> qkvInputHost(qkvInputSize / sizeof(half));
for (size_t i = 0; i < qkvInputHost.size(); ++i)
{
qkvInputHost[i] = half(float(i % 1000) / 100.0f); // Simple test pattern
}
TLLM_CUDA_CHECK(cudaMemcpy(mQkvInputDevice, qkvInputHost.data(), qkvInputSize, cudaMemcpyHostToDevice));
// Initialize KV cache with initial data pattern for testing
initializeKvCacheWithPattern();
// Extract the original KV cache data before kernel execution for verification
size_t totalKvElements = mBatchSize * mMaxSeqLen * mNumKvHeads * mHeadSize * 2; // K and V
std::vector<half> originalKvCache(totalKvElements);
extractKvCacheFromGpu(originalKvCache);
// Setup QKVPreprocessingParams
QKVPreprocessingParams<half, KVBlockArray> params;
memset(&params, 0, sizeof(params));
params.qkv_input = mQkvInputDevice;
params.kv_cache_buffer = mKvCacheBuffer;
params.sparse_kv_indices = mSparseKvIndicesDevice;
params.sparse_kv_offsets = mSparseKvOffsetsDevice;
params.seq_lens = mSeqLensDevice;
params.cache_seq_lens = mCacheSeqLensDevice;
params.batch_size = mBatchSize;
params.head_num = mNumKvHeads; // For Q heads, assuming same as KV heads for this test
params.kv_head_num = mNumKvHeads;
params.size_per_head = mHeadSize;
params.cache_type = KvCacheDataType::BASE;
params.rotary_embedding_dim = 0; // No rotary embedding for this test
params.setCommonParameters();
// Verify sparse indices and offsets on host
std::vector<int> hostSparseIndices(8 * mNumKvHeads);
TLLM_CUDA_CHECK(cudaMemcpy(hostSparseIndices.data(), mSparseKvIndicesDevice, hostSparseIndices.size() * sizeof(int),
cudaMemcpyDeviceToHost));
std::vector<int> hostSparseOffsets(mBatchSize + 1);
TLLM_CUDA_CHECK(cudaMemcpy(hostSparseOffsets.data(), mSparseKvOffsetsDevice, hostSparseOffsets.size() * sizeof(int),
cudaMemcpyDeviceToHost));
TLLM_CUDA_CHECK(cudaDeviceSynchronize());
cudaError_t pre_kernel_error = cudaGetLastError();
if (pre_kernel_error != cudaSuccess)
{
printf("Debug: CUDA error before kernel call: %s\n", cudaGetErrorString(pre_kernel_error));
}
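    // Launch the kernel under test: it gathers each head's selected sparse tokens
    // into contiguous slots at the front of the sequence's K/V cache blocks.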
invokeUpdateSparseKvCacheAfterFmha<half, half, KVBlockArray>(params, mStream);
TLLM_CUDA_CHECK(cudaStreamSynchronize(mStream));
cudaError_t post_kernel_error = cudaGetLastError();
if (post_kernel_error != cudaSuccess)
{
printf("Debug: CUDA error after kernel call: %s\n", cudaGetErrorString(post_kernel_error));
}
else
{
printf("Debug: Kernel call completed, no immediate CUDA errors\n");
}
// Verification: Compare GPU result with CPU reference implementation
EXPECT_TRUE(verifySparseKvCacheMapping(originalKvCache));
}
// Implementation of verification functions
bool SparseKvCacheTest::verifySparseKvCacheMapping(std::vector<half> const& originalKvCache)
{
// Perform host-side sparse mapping to get expected result
size_t totalKvElements = originalKvCache.size();
std::vector<half> expectedKvCache{originalKvCache};
performHostSparseMapping(originalKvCache, expectedKvCache);
// Extract actual result from GPU after sparse kernel execution
std::vector<half> actualKvCache(totalKvElements);
extractKvCacheFromGpu(actualKvCache);
// Compare results with tolerance - only for valid sparse tokens
float const tolerance = 1e-5f;
bool passed = true;
int errorCount = 0;
int const maxErrorsToShow = 10;
size_t totalValidElements = 0;
// Only compare the sparse tokens that should have been reorganized
std::vector<int> sparseTokenCounts = {5, 3}; // Batch 0: 5 tokens, Batch 1: 3 tokens
for (int batch = 0; batch < mBatchSize; ++batch)
{
int numSparseTokens = sparseTokenCounts[batch];
for (int kv = 0; kv < 2; ++kv) // K and V
{
for (int token = 0; token < numSparseTokens; ++token)
{
for (int head = 0; head < mNumKvHeads; ++head)
{
for (int dim = 0; dim < mHeadSize; ++dim)
{
// Calculate index in flat array
// Layout: [batch][kv][token][head][dim]
size_t idx = batch * (2 * mMaxSeqLen * mNumKvHeads * mHeadSize)
+ kv * (mMaxSeqLen * mNumKvHeads * mHeadSize) + token * (mNumKvHeads * mHeadSize)
+ head * mHeadSize + dim;
if (idx < totalKvElements)
{
float expected = float(expectedKvCache[idx]);
float actual = float(actualKvCache[idx]);
float diff = std::abs(expected - actual);
if (diff > tolerance)
{
if (errorCount < maxErrorsToShow)
{
printf(
"Mismatch at batch=%d, kv=%d, token=%d, head=%d, dim=%d: expected %.6f, got "
"%.6f, diff %.6f\n",
batch, kv, token, head, dim, expected, actual, diff);
}
errorCount++;
passed = false;
}
totalValidElements++;
}
}
}
}
}
}
if (errorCount > 0)
{
printf("Total errors: %d out of %zu valid sparse token elements\n", errorCount, totalValidElements);
}
else
{
printf("Verification passed: all %zu valid sparse token elements match within tolerance %.2e\n",
totalValidElements, tolerance);
}
return passed;
}
void SparseKvCacheTest::performHostSparseMapping(
std::vector<half> const& originalKvCache, std::vector<half>& expectedKvCache)
{
// Host-side reference implementation of sparse KV cache mapping
// This is a naive but correct implementation for verification
// Read sparse indices from GPU memory to get the actual data being used
std::vector<int> hostSparseIndices(8 * mNumKvHeads);
TLLM_CUDA_CHECK(cudaMemcpy(hostSparseIndices.data(), mSparseKvIndicesDevice, hostSparseIndices.size() * sizeof(int),
cudaMemcpyDeviceToHost));
std::vector<int> hostSparseOffsets(mBatchSize + 1);
TLLM_CUDA_CHECK(cudaMemcpy(hostSparseOffsets.data(), mSparseKvOffsetsDevice, hostSparseOffsets.size() * sizeof(int),
cudaMemcpyDeviceToHost));
TLLM_CUDA_CHECK(cudaDeviceSynchronize());
// Process each batch
for (int batch = 0; batch < mBatchSize; ++batch)
{
int const sparse_start_idx = hostSparseOffsets[batch];
int const sparse_end_idx = hostSparseOffsets[batch + 1];
int const num_sparse_tokens = sparse_end_idx - sparse_start_idx;
// Process each head
for (int head = 0; head < mNumKvHeads; ++head)
{
// Process both K and V cache
for (int kv = 0; kv < 2; ++kv) // 0 = K, 1 = V
{
// For each sparse token in this batch
for (int sparseTokenOffset = 0; sparseTokenOffset < num_sparse_tokens; ++sparseTokenOffset)
{
// Calculate index in new format: [num_kv_heads, num_sparse_kv_tokens]
int const global_sparse_idx = sparse_start_idx + sparseTokenOffset;
int const sparse_idx_offset = head * 8 + global_sparse_idx; // 8 is total sparse tokens
int const originalTokenIdx = hostSparseIndices[sparse_idx_offset];
int const continuousTokenIdx = sparseTokenOffset;
// Copy from original position to continuous position
for (int dim = 0; dim < mHeadSize; ++dim)
{
// Calculate indices in the flat array
// Layout: [batch][kv][token][head][dim]
size_t srcIdx = batch * (2 * mMaxSeqLen * mNumKvHeads * mHeadSize)
+ kv * (mMaxSeqLen * mNumKvHeads * mHeadSize) + originalTokenIdx * (mNumKvHeads * mHeadSize)
+ head * mHeadSize + dim;
size_t dstIdx = batch * (2 * mMaxSeqLen * mNumKvHeads * mHeadSize)
+ kv * (mMaxSeqLen * mNumKvHeads * mHeadSize)
+ continuousTokenIdx * (mNumKvHeads * mHeadSize) + head * mHeadSize + dim;
if (srcIdx < originalKvCache.size() && dstIdx < expectedKvCache.size())
{
expectedKvCache[dstIdx] = originalKvCache[srcIdx];
}
}
}
}
}
}
}
void SparseKvCacheTest::extractKvCacheFromGpu(std::vector<half>& kvCacheHost)
{
// Extract KV cache data from GPU KVBlockArray structure
// This is a simplified extraction for testing purposes
// Calculate total size needed
size_t totalElements = mBatchSize * mMaxSeqLen * mNumKvHeads * mHeadSize * 2;
kvCacheHost.resize(totalElements);
// For testing, we'll use a simplified approach to read back the cache
// In a real implementation, this would need to handle the block structure properly
// Calculate pool size
auto const elemSize = sizeof(half);
auto const sizePerToken = mNumKvHeads * mHeadSize * elemSize;
auto const bytesPerBlock = mTokensPerBlock * sizePerToken;
auto const totalBlocks = mBatchSize * mMaxBlocksPerSeq;
auto const poolSize = totalBlocks * bytesPerBlock * 2; // K and V
// Create temporary buffer to read entire pool
std::vector<half> poolData(poolSize / sizeof(half));
TLLM_CUDA_CHECK(cudaMemcpy(poolData.data(), mKvCachePool, poolSize, cudaMemcpyDeviceToHost));
TLLM_CUDA_CHECK(cudaDeviceSynchronize());
// Reorganize from block structure to linear structure for comparison
// This is a simplified mapping - in reality you'd need to handle block indexing
for (int batch = 0; batch < mBatchSize; ++batch)
{
for (int token = 0; token < mMaxSeqLen; ++token)
{
for (int head = 0; head < mNumKvHeads; ++head)
{
for (int dim = 0; dim < mHeadSize; ++dim)
{
for (int kv = 0; kv < 2; ++kv)
{
// Calculate block coordinates
int blockIdx = token / mTokensPerBlock;
int tokenInBlock = token % mTokensPerBlock;
// Calculate source index in pool (simplified)
// The layout of a block in the pool is [mTokensPerBlock, mNumKvHeads, mHeadSize]
size_t block_base_pool_idx
= (size_t) (batch * mMaxBlocksPerSeq * 2 + kv * mMaxBlocksPerSeq + blockIdx)
* mTokensPerBlock * mNumKvHeads * mHeadSize;
size_t inner_block_pool_idx
= (size_t) head * mTokensPerBlock * mHeadSize + (size_t) tokenInBlock * mHeadSize + dim;
size_t poolIdx = block_base_pool_idx + inner_block_pool_idx;
// Calculate destination index in linear layout
size_t linearIdx = (size_t) batch * (2 * mMaxSeqLen * mNumKvHeads * mHeadSize)
+ (size_t) kv * (mMaxSeqLen * mNumKvHeads * mHeadSize)
+ (size_t) token * (mNumKvHeads * mHeadSize) + (size_t) head * mHeadSize + dim;
if (poolIdx < poolData.size() && linearIdx < kvCacheHost.size())
{
kvCacheHost[linearIdx] = poolData[poolIdx];
}
}
}
}
}
}
}
void SparseKvCacheTest::initializeKvCacheWithPattern()
{
// Calculate pool size
auto const elemSize = sizeof(half);
auto const sizePerToken = mNumKvHeads * mHeadSize * elemSize;
auto const bytesPerBlock = mTokensPerBlock * sizePerToken;
auto const totalBlocks = mBatchSize * mMaxBlocksPerSeq;
auto const poolSize = totalBlocks * bytesPerBlock * 2; // K and V
// Create host data with recognizable pattern
std::vector<half> poolData(poolSize / sizeof(half));
for (size_t i = 0; i < poolData.size(); ++i)
{
poolData[i] = half(float(i) / 1000.0f);
}
// Copy to GPU
TLLM_CUDA_CHECK(cudaMemcpy(mKvCachePool, poolData.data(), poolSize, cudaMemcpyHostToDevice));
TLLM_CUDA_CHECK(cudaDeviceSynchronize());
}
// Main function for standalone compilation
int main(int argc, char** argv)
{
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}

View File

@ -0,0 +1,155 @@
### :title Sparse Attention
### :order 5
### :section Customization
"""
This example demonstrates how to use sparse attention with TensorRT-LLM.
Supported sparse attention algorithms:
- RocketKV
Usage:
```bash
python llm_sparse_attention.py --algo ROCKETKV --attention_backend TRTLLM --window_size 32 --kernel_size 63 --prompt_budget 2048
```
"""
import argparse
import json
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import KvCacheConfig, RocketSparseAttentionConfig
def read_input(input_file):
results = []
with open(input_file, 'r') as f:
for line in f:
ret = json.loads(line)
results.append(ret)
return results
def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument(
'--model_path',
type=str,
default=
"/home/scratch.trt_llm_data/llm-models/llama-3.1-model/Llama-3.1-8B-Instruct"
)
parser.add_argument(
'--input_file',
type=str,
default="tests/unittest/_torch/multi_gpu/test_star_attention_input.jsonl"
)
# Build config
parser.add_argument('--algo',
type=str,
default='ROCKETKV',
choices=['ROCKETKV'])
parser.add_argument('--attention_backend',
type=str,
default='TRTLLM',
choices=['VANILLA', 'TRTLLM'])
parser.add_argument('--window_size',
type=int,
default=32,
help="The window size for RocketKV.")
parser.add_argument('--kernel_size',
type=int,
default=63,
help="The kernel size for RocketKV.")
parser.add_argument('--prompt_budget',
type=int,
default=2048,
help="The prompt budget for RocketKV.")
parser.add_argument("--max_seq_len",
type=int,
default=8192,
help="The maximum sequence length.")
parser.add_argument("--max_batch_size",
type=int,
default=256,
help="The maximum batch size.")
parser.add_argument("--max_new_tokens",
type=int,
default=128,
help="The maximum new tokens.")
parser.add_argument(
"--max_num_tokens",
type=int,
default=8192,
help=
"The maximum total tokens (context + generation) across all sequences in a batch."
)
parser.add_argument('--tensor_parallel_size', type=int, default=1)
# KV cache
parser.add_argument('--kv_cache_dtype', type=str, default='auto')
parser.add_argument("--kv_cache_fraction", type=float, default=None)
parser.add_argument('--num_samples', type=int, default=10)
args = parser.parse_args()
return args
def run_RocketKV(args):
data = read_input(args.input_file)
num_samples = args.num_samples if args.num_samples is not None else len(
data)
data = data[:num_samples]
kv_cache_config = KvCacheConfig(
        enable_block_reuse=False,  # sparse attention does not support KV cache reuse yet
free_gpu_memory_fraction=args.kv_cache_fraction,
dtype=args.kv_cache_dtype,
)
sparse_attention_config = RocketSparseAttentionConfig(
window_size=args.window_size,
kernel_size=args.kernel_size,
prompt_budget=args.prompt_budget,
)
llm = LLM(
model=args.model_path,
backend='pytorch',
kv_cache_config=kv_cache_config,
attn_backend=args.attention_backend,
sparse_attention_config=sparse_attention_config,
max_batch_size=args.max_batch_size,
max_seq_len=args.max_seq_len,
max_num_tokens=args.max_num_tokens,
tensor_parallel_size=args.tensor_parallel_size,
        cuda_graph_config=None,  # sparse attention does not support CUDA graphs yet
)
prompts = []
reference = []
for sample in data:
prompts.append(
{'prompt': sample['input_context'] + sample['input_query']})
reference.append(sample['outputs'])
sampling_params = SamplingParams(add_special_tokens=False,
max_tokens=args.max_new_tokens,
temperature=0.8,
top_p=0.95)
outputs = llm.generate(prompts, sampling_params)
for idx, output in enumerate(outputs):
print(
f'Generated text: {output.outputs[0].text!r}, ref: {reference[idx]}'
)
def main():
args = parse_arguments()
if args.algo == 'ROCKETKV':
run_RocketKV(args)
else:
raise ValueError(f"Invalid algorithm: {args.algo}")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,170 @@
# LongBench Evaluation with TensorRT-LLM and Sparse Attention
This directory contains evaluation scripts for both the LongBench v1 and LongBench v2 datasets using the TensorRT-LLM backend.
## Environment Setup
### 1. Clone LongBench Repository
First, clone the LongBench repository which contains the datasets and evaluation utilities:
```bash
git clone https://github.com/THUDM/LongBench.git
```
### 2. Install Requirements
Install LongBench's required dependencies (its `requirements.txt` ships inside the cloned repository):
```bash
pip install -r LongBench/requirements.txt
```
### 3. Directory Structure
After cloning, your directory structure should look like:
```text
sparse_attention/
├── eval_longbench_v1.py # LongBench v1 evaluation script
├── eval_longbench_v2.py # LongBench v2 evaluation script
├── README.md # This file
└── LongBench/ # Cloned LongBench repository
├── LongBench/ # LongBench v1 data and configs
│ ├── config/
│ └── ...
├── config/ # LongBench v2 configs
├── ...
└── requirements.txt
```
## Scripts Overview
### 1. eval_longbench_v1.py
This script evaluates models on the **LongBench v1** dataset, which comprises task-specific benchmarks such as narrativeqa, qasper, and multifieldqa. Key features:
- **Dataset**: LongBench v1 with task-specific evaluation
- **Tasks**: Support for 20+ different long-context tasks
- **Prompts**: Task-specific prompts from LongBench v1 configuration
- **Metrics**: Task-specific metrics (F1, ROUGE, classification scores, etc.)
- **Output**: Task-level results with comprehensive summary statistics
### 2. eval_longbench_v2.py
This script evaluates models on the **LongBench v2** dataset, which is a standardized multiple-choice format. Key features:
- **Dataset**: LongBench v2 with unified multiple-choice format
- **Format**: All questions are A/B/C/D multiple choice
- **Context Length**: 8K to 2M words (majority under 128K)
- **Difficulty**: Easy/Hard categorization
- **Length**: Short/Medium/Long categorization
- **Domains**: Various domains (single-doc QA, multi-doc QA, code, etc.)
- **CoT Support**: Chain-of-Thought reasoning support
- **Metrics**: Accuracy with breakdowns by difficulty, length, and domain
## Usage Examples
### LongBench v1 Evaluation
#### Basic Usage (Standard Attention)
```bash
python eval_longbench_v1.py \
--model_path "/path/to/your/model" \
--longbench_path ./LongBench \
--output_dir results/v1_vanilla \
--attention_backend VANILLA \
--backend pytorch
```
#### Specific Tasks with Sparse Attention (RocketKV)
```bash
python eval_longbench_v1.py \
--model_path "/path/to/your/model" \
--longbench_path ./LongBench \
--dataset narrativeqa qasper \
--output_dir results/v1_rocket \
--attention_backend VANILLA \
--backend pytorch \
--rocket_sparse
```
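#### Run All Tasks with Sparse Attention (RocketKV)
To sweep every LongBench v1 task in one run, combine `--run_all_tasks` with the RocketKV flags exposed by `eval_longbench_v1.py`; the sketch below uses a placeholder model path and the script's default budget of 2048 tokens.
```bash
python eval_longbench_v1.py \
    --model_path "/path/to/your/model" \
    --longbench_path ./LongBench \
    --output_dir results/v1_rocket_all \
    --attention_backend VANILLA \
    --backend pytorch \
    --rocket_sparse \
    --run_all_tasks \
    --token_budget 2048
```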
### LongBench v2 Evaluation
#### Basic Usage (Standard Attention)
```bash
python eval_longbench_v2.py \
--model_path "/path/to/your/model" \
--longbench_path ./LongBench \
--output_dir results/v2_vanilla
```
#### With Chain-of-Thought Reasoning
```bash
python eval_longbench_v2.py \
--model_path "/path/to/your/model" \
--longbench_path ./LongBench \
--output_dir results/v2_cot \
--cot
```
#### Filter by Difficulty/Length/Domain
```bash
# Easy questions only
python eval_longbench_v2.py \
--model_path "/path/to/your/model" \
--longbench_path ./LongBench \
--output_dir results/v2_easy \
--difficulty easy
# Long context only
python eval_longbench_v2.py \
--model_path "/path/to/your/model" \
--longbench_path ./LongBench \
--output_dir results/v2_long \
--length long
# Specific domain
python eval_longbench_v2.py \
--model_path "/path/to/your/model" \
--longbench_path ./LongBench \
--output_dir results/v2_code \
--domain "Code"
```
#### Limited Sample Evaluation (for testing)
```bash
python eval_longbench_v2.py \
--model_path "/path/to/your/model" \
--longbench_path ./LongBench \
--output_dir results/v2_test \
--num_samples 10
```
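#### With Sparse Attention (RocketKV)
`eval_longbench_v2.py` accepts the same RocketKV options as the v1 script (`--rocket_sparse`, `--token_budget`, `--window_size`, `--kernel_size`); the invocation below is a sketch with a placeholder model path.
```bash
python eval_longbench_v2.py \
    --model_path "/path/to/your/model" \
    --longbench_path ./LongBench \
    --output_dir results/v2_rocket \
    --rocket_sparse \
    --token_budget 2048
```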
## Output Structure
### LongBench v1 Output
```text
results/v1_experiment/
├── config.json # Experiment configuration
├── overall_summary.json # Overall experiment summary
├── narrativeqa/
│ ├── narrativeqa_results.jsonl # Detailed results
│ ├── narrativeqa_summary.json # Task summary
│ └── pred/
│ └── narrativeqa.jsonl # Predictions in LongBench format
├── qasper/
│ └── ...
└── ...
```
### LongBench v2 Output
```text
results/v2_experiment/
├── config.json # Experiment configuration
├── summary.json # Evaluation summary with metrics
├── longbench_v2_results.jsonl # Detailed results
└── predictions.jsonl # Predictions in LongBench v2 format
```

View File

@ -0,0 +1,797 @@
#!/usr/bin/env python3
"""
LongBench v1 evaluation script with TensorRT-LLM and sparse attention.
Usage:
python eval_longbench_v1.py --dataset narrativeqa --model_path /path/to/model --longbench_path ./LongBench --output_dir results/
# Run all LongBench tasks with RocketKV sparse attention
python eval_longbench_v1.py --model_path /path/to/model --longbench_path ./LongBench --output_dir results/ --run_all_tasks --token_budget 2048 --rocket_sparse
"""
import argparse
import json
import os
import sys
import time
from datetime import datetime
from typing import Any, Dict, List, Tuple
import numpy as np
from datasets import load_dataset
from transformers import AutoTokenizer
# Add tensorrt_llm imports
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import KvCacheConfig, RocketSparseAttentionConfig
from tensorrt_llm.logger import logger
LONGBENCH_DATASETS = ["narrativeqa", "qasper", "multifieldqa_en", "multifieldqa_zh", "hotpotqa", "2wikimqa", "musique", \
"dureader", "gov_report", "qmsum", "multi_news", "vcsum", "trec", "triviaqa", "samsum", "lsht", \
"passage_count", "passage_retrieval_en", "passage_retrieval_zh", "lcc", "repobench-p"]
# Task categorization
TASK_DATASETS = {
'single_doc_qa':
["narrativeqa", "qasper", "multifieldqa_en", "multifieldqa_zh"],
'multi_doc_qa': ["hotpotqa", "2wikimqa", "musique", "dureader"],
'summarization': ["gov_report", "qmsum", "multi_news", "vcsum"],
'few_shots': ["trec", "triviaqa", "samsum", "lsht"],
'synthetic':
["passage_count", "passage_retrieval_en", "passage_retrieval_zh"],
'code': ["lcc", "repobench-p"]
}
# Chat templates mapping
CHAT_TEMPLATES = {
"llama3.1-8b-instruct": "llama3",
"llama3-8b-instruct": "llama3",
"mistral-7b-instruct-v0.2": "mistral",
"longchat-7b-v1.5-32k": "vicuna"
}
def parse_arguments() -> argparse.Namespace:
"""Parse command line arguments."""
parser = argparse.ArgumentParser(
description="LongBench evaluation with TensorRT-LLM and RocketKV")
# Model and data arguments
parser.add_argument('--model_path',
type=str,
required=True,
help='Path to model (HF model name or local path)')
parser.add_argument('--dataset',
type=str,
nargs='+',
choices=LONGBENCH_DATASETS,
help='LongBench datasets to evaluate on')
parser.add_argument('--run_all_tasks',
action='store_true',
help='Run evaluation on all LongBench tasks')
parser.add_argument('--longbench_path',
type=str,
default='./LongBench',
help='Path to LongBench directory')
# Output arguments
parser.add_argument('--output_dir',
type=str,
required=True,
help='Directory to save results')
parser.add_argument('--exp_name',
type=str,
default=None,
help='Experiment name (auto-generated if not provided)')
# Model configuration
parser.add_argument('--attention_backend',
type=str,
default='VANILLA',
choices=['VANILLA', 'TRTLLM', 'FLASHINFER'],
help='Attention backend to use')
parser.add_argument('--backend',
type=str,
default='pytorch',
choices=['pytorch', 'tensorrt'],
help='LLM backend to use')
parser.add_argument('--chat_template',
type=str,
default='auto',
help='Chat template to use (auto-detect if "auto")')
# Sequence and batch configuration
parser.add_argument('--max_seq_len',
type=int,
default=133120,
help='Maximum sequence length')
parser.add_argument('--max_batch_size',
type=int,
default=1,
help='Maximum batch size')
parser.add_argument('--max_new_tokens',
type=int,
default=256,
help='Maximum new tokens to generate')
parser.add_argument(
'--max_num_tokens',
type=int,
default=133120,
help='Maximum total tokens across all sequences in a batch')
parser.add_argument('--tensor_parallel_size',
type=int,
default=1,
help='Tensor parallel size')
# RocketKV configuration
parser.add_argument('--rocket_sparse',
action='store_true',
help='Use rocket sparse attention')
parser.add_argument('--token_budget',
type=int,
default=2048,
help='Token budget for RocketKV (prompt_budget)')
parser.add_argument('--window_size',
type=int,
default=32,
help='Window size for RocketKV')
parser.add_argument('--kernel_size',
type=int,
default=63,
help='Kernel size for RocketKV')
parser.add_argument('--topr',
type=int,
default=90,
help='Top-r for RocketKV')
# KV cache configuration
parser.add_argument('--kv_cache_dtype',
type=str,
default='auto',
help='KV cache data type')
parser.add_argument('--kv_cache_fraction',
type=float,
default=0.7,
help='Fraction of GPU memory for KV cache')
# Evaluation parameters
parser.add_argument('--num_samples',
type=int,
default=None,
help='Number of samples to evaluate (None for all)')
parser.add_argument('--start_idx',
type=int,
default=0,
help='Start index for evaluation')
# System arguments
parser.add_argument('--log_level',
type=str,
default='info',
choices=['debug', 'info', 'warning', 'error'],
help='Logging level')
parser.add_argument('--seed', type=int, default=42, help='Random seed')
args = parser.parse_args()
# Validation
if not args.run_all_tasks and not args.dataset:
parser.error("Must specify either --dataset or --run_all_tasks")
return args
def setup_longbench_imports(longbench_path: str):
"""Add LongBench to Python path and import required modules."""
longbench_dir = os.path.join(longbench_path, "LongBench") # for v1
if not os.path.exists(longbench_dir):
raise FileNotFoundError(
f"LongBench directory not found: {longbench_dir}")
# Add to path
if longbench_dir not in sys.path:
sys.path.insert(0, longbench_dir)
def load_longbench_config(
longbench_path: str) -> Tuple[Dict[str, str], Dict[str, int]]:
"""Load LongBench configuration files."""
config_dir = os.path.join(longbench_path, "LongBench", "config")
# Load dataset2prompt.json
prompt_file = os.path.join(config_dir, "dataset2prompt.json")
with open(prompt_file, 'r', encoding='utf-8') as f:
dataset2prompt = json.load(f)
# Load dataset2maxlen.json
maxlen_file = os.path.join(config_dir, "dataset2maxlen.json")
with open(maxlen_file, 'r', encoding='utf-8') as f:
dataset2maxlen = json.load(f)
return dataset2prompt, dataset2maxlen
# LongBench's build_chat function (simplified version)
def build_chat(tokenizer, prompt, chat_template):
"""Build chat prompt following LongBench's approach."""
if chat_template == "vicuna" or chat_template == "longchat":
try:
from fastchat.model import get_conversation_template
conv = get_conversation_template("vicuna")
conv.append_message(conv.roles[0], prompt)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
except ImportError:
# Fallback if fastchat is not available
prompt = f"A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n\nUSER: {prompt}\nASSISTANT:"
elif chat_template == "llama2":
prompt = f"[INST]{prompt}[/INST]"
elif chat_template == "llama3":
prompt = f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
elif chat_template == "mistral":
prompt = f"[INST] {prompt} [/INST]"
# For other templates or "none", return prompt as-is
return prompt
def determine_chat_template(model_path: str, chat_template: str) -> str:
"""Determine chat template based on model path."""
if chat_template != 'auto':
return chat_template
model_path_lower = model_path.lower()
for model_key, template in CHAT_TEMPLATES.items():
if model_key.replace('-', '').replace('.',
'') in model_path_lower.replace(
'-', '').replace('.', ''):
return template
# Default fallback
if 'llama' in model_path_lower:
return 'llama3'
elif 'mistral' in model_path_lower:
return 'mistral'
else:
return 'none' # No special formatting
def post_process(pred: str, chat_template: str, dataset: str) -> str:
"""Post-process prediction following LongBench's approach."""
pred = pred.split("</s")[0].strip()
if chat_template == "qwen":
pred = pred.split("<|im_end|>")[0]
elif "llama2" in chat_template.lower():
pred = (pred.split("(Document")[0].split("\n\nQuestion")[0].split(
"\n\nAnswer")[0].split("[INST]")[0].split("[/INST]")[0].split(
"(Passage")[0].strip())
if dataset == "samsum":
pred = pred.split("\n")[0].strip()
return pred
def format_prompt_style(sample: Dict[str, Any], instruction: str,
chat_template: str, dataset: str, tokenizer) -> str:
"""Format prompt following LongBench's approach."""
# First format the instruction using the sample data
prompt = instruction.format(**sample)
if dataset not in [
"trec", "triviaqa", "samsum", "lsht", "lcc", "repobench-p"
]:
prompt = build_chat(tokenizer, prompt, chat_template)
return prompt
def initialize_llm(args: argparse.Namespace) -> Tuple[LLM, AutoTokenizer]:
"""Initialize LLM and tokenizer."""
logger.info(f"Initializing LLM with model: {args.model_path}")
try:
# Configure KV cache
kv_cache_config = KvCacheConfig(
enable_block_reuse=False, # RocketKV doesn't support KV cache reuse
)
if args.rocket_sparse:
# Configure RocketKV sparse attention
sparse_attention_config = RocketSparseAttentionConfig(
window_size=args.window_size,
kernel_size=args.kernel_size,
prompt_budget=args.token_budget,
topr=args.topr,
)
logger.info(f"Using RocketKV sparse attention")
else:
sparse_attention_config = None
logger.info("Using standard attention")
# Initialize LLM
llm = LLM(
model=args.model_path,
backend=args.backend,
kv_cache_config=kv_cache_config,
max_batch_size=args.max_batch_size,
attn_backend=args.attention_backend,
sparse_attention_config=sparse_attention_config,
tensor_parallel_size=args.tensor_parallel_size,
max_seq_len=args.max_seq_len,
max_num_tokens=args.max_num_tokens,
cuda_graph_config=None,
torch_compile_config=None,
)
# Initialize tokenizer
tokenizer = AutoTokenizer.from_pretrained(args.model_path)
logger.info("LLM and tokenizer initialized successfully")
return llm, tokenizer
except Exception as e:
logger.error(f"Failed to initialize LLM: {e}")
raise e
def evaluate_single_dataset(
dataset: str, llm: LLM, tokenizer: AutoTokenizer,
args: argparse.Namespace) -> Tuple[List[Dict], float]:
"""Evaluate a single dataset."""
setup_longbench_imports(args.longbench_path)
dataset2prompt, dataset2maxlen = load_longbench_config(args.longbench_path)
# Load dataset
logger.info(f"Loading dataset: {dataset}")
data = [
data_sample for data_sample in load_dataset(
'THUDM/LongBench', dataset, split='test', trust_remote_code=True)
]
# Apply data filtering
if args.num_samples:
end_idx = min(args.start_idx + args.num_samples, len(data))
filtered_data = data[args.start_idx:end_idx]
else:
filtered_data = data[args.start_idx:]
logger.info(f"Dataset {dataset}: {len(filtered_data)} samples to evaluate")
# Determine chat template
chat_template = determine_chat_template(args.model_path, args.chat_template)
logger.info(f"Using chat template: {chat_template}")
# Create sampling parameters
max_new_tokens = dataset2maxlen[dataset]
prompt_format = dataset2prompt[dataset]
# Set up extra end token ids
extra_end_token_ids = []
if chat_template == "llama3":
eot_id = tokenizer.encode("<|eot_id|>", add_special_tokens=False)[0]
extra_end_token_ids.append(eot_id)
logger.info(f"Added llama3 end token: {eot_id}")
if chat_template == "qwen":
im_end_id = tokenizer.encode("<|im_end|>", add_special_tokens=False)[0]
extra_end_token_ids.append(im_end_id)
logger.info(f"Added qwen end token: {im_end_id}")
if dataset == "samsum":
newline_id = tokenizer.encode("\n", add_special_tokens=False)[-1]
extra_end_token_ids.append(newline_id)
logger.info(f"Added samsum newline token: {newline_id}")
# Prepare prompts
prompts = []
for sample in filtered_data:
formatted_prompt = format_prompt_style(sample, prompt_format,
chat_template, dataset,
tokenizer)
prompts.append(formatted_prompt)
if len(prompts) == 0:
logger.warning(f"No prompts to evaluate for dataset {dataset}")
return [], 0.0
# Run inference
logger.info(f"Starting inference for {len(prompts)} samples...")
start_time = time.time()
sampling_params = SamplingParams(
max_tokens=max_new_tokens,
temperature=0.8,
top_p=0.95,
stop_token_ids=extra_end_token_ids if extra_end_token_ids else None,
)
outputs = llm.generate(prompts, sampling_params)
inference_time = time.time() - start_time
logger.info(
f"Inference completed in {inference_time:.2f} seconds, average time per sample: {inference_time/len(prompts):.3f} seconds"
)
# Prepare results
results = []
for i, (sample, output) in enumerate(zip(filtered_data, outputs)):
prediction = output.outputs[0].text.strip()
processed_prediction = post_process(prediction, chat_template, dataset)
result = {
'sample_id': args.start_idx + i,
'input': sample.get('input', ''),
'context': sample.get('context', ''),
'answers': sample.get('answers', []),
'all_classes': sample.get('all_classes', []),
'prediction': processed_prediction,
'raw_prediction': prediction,
'prompt_length': len(output.prompt_token_ids),
'output_length': len(output.outputs[0].token_ids),
'inference_time': getattr(output, 'inference_time', None),
'length': sample.get('length', 0)
}
results.append(result)
return results, inference_time
def calculate_metrics(
dataset: str, predictions: List[str], answers_list: List[List[str]],
all_classes_list: List[List[str]],
longbench_path: str) -> Tuple[Dict[str, float], List[float]]:
"""Calculate evaluation metrics for a dataset following LongBench's implementation."""
# Setup LongBench imports
setup_longbench_imports(longbench_path)
# Import LongBench metrics
from metrics import (classification_score, code_sim_score, count_score,
qa_f1_score, qa_f1_zh_score, retrieval_score,
retrieval_zh_score, rouge_score, rouge_zh_score)
# Mapping of datasets to their metric functions (from LongBench)
dataset2metric = {
"narrativeqa": qa_f1_score,
"qasper": qa_f1_score,
"multifieldqa_en": qa_f1_score,
"multifieldqa_zh": qa_f1_zh_score,
"hotpotqa": qa_f1_score,
"2wikimqa": qa_f1_score,
"musique": qa_f1_score,
"dureader": rouge_zh_score,
"gov_report": rouge_score,
"qmsum": rouge_score,
"multi_news": rouge_score,
"vcsum": rouge_zh_score,
"trec": classification_score,
"triviaqa": qa_f1_score,
"samsum": rouge_score,
"lsht": classification_score,
"passage_retrieval_en": retrieval_score,
"passage_count": count_score,
"passage_retrieval_zh": retrieval_zh_score,
"lcc": code_sim_score,
"repobench-p": code_sim_score,
}
if dataset not in dataset2metric:
# Fallback to simple exact match with cleaning
total_score = 0
scores = []
for pred, answers in zip(predictions, answers_list):
cleaned_pred = pred.lstrip('\n').split('\n')[0].strip()
score = max([
1.0 if cleaned_pred.lower() == ans.strip().lower() else 0.0
for ans in answers
])
scores.append(score)
total_score += score
return {
"exact_match": round(100 * total_score / len(predictions), 2)
}, scores
metric_func = dataset2metric[dataset]
total_score = 0.0
scores = []
# Follow LongBench's scorer function exactly
for pred, ground_truths, all_classes in zip(predictions, answers_list,
all_classes_list):
score = 0.0
# Apply the same prediction cleaning as LongBench
if dataset in ["trec", "triviaqa", "samsum", "lsht"]:
pred = pred.lstrip('\n').split('\n')[0]
# For code datasets, apply additional cleaning
if dataset in ["lcc", "repobench-p"]:
# This cleaning is done inside code_sim_score, but let's also apply it here for consistency
all_lines = pred.lstrip('\n').split('\n')
for line in all_lines:
if ('`' not in line) and ('#' not in line) and ('//'
not in line):
pred = line
break
# Calculate max score across all reference answers (exactly as in LongBench)
for ground_truth in ground_truths:
score = max(
score, metric_func(pred, ground_truth, all_classes=all_classes))
scores.append(score)
total_score += score
final_score = round(100 * total_score / len(predictions), 2)
return {metric_func.__name__: final_score}, scores
def calculate_task_summary(all_results: Dict[str, Dict]) -> Dict[str, Any]:
"""Calculate task-level summary statistics following long_bench_tasks_summary.py approach."""
logger.info("Calculating task-level summary statistics...")
summary = {}
ind_dataset_result = {}
task_ave_result = {}
NA_flag = False
# Get individual dataset results
for dataset in LONGBENCH_DATASETS:
if dataset in all_results and 'metrics' in all_results[dataset]:
metrics = all_results[dataset]['metrics']
# Get the first (and usually only) metric value
if metrics:
metric_key = list(metrics.keys())[0]
val = metrics[metric_key]
ind_dataset_result[dataset] = val
else:
ind_dataset_result[dataset] = 'N/A'
NA_flag = True
else:
ind_dataset_result[dataset] = 'N/A'
NA_flag = True
summary['individual_dataset_result'] = ind_dataset_result
# Calculate task-average results
for task, datasets in TASK_DATASETS.items():
task_NA_flag = False
task_ave_result[task] = 0
valid_count = 0
for dataset in datasets:
if dataset in ind_dataset_result and ind_dataset_result[
dataset] != 'N/A':
task_ave_result[task] += ind_dataset_result[dataset]
valid_count += 1
else:
task_NA_flag = True
if task_NA_flag or valid_count == 0:
task_ave_result[task] = 'N/A'
else:
task_ave_result[task] = np.round(task_ave_result[task] /
valid_count,
decimals=2)
summary['task_average_result'] = task_ave_result
# Calculate overall LongBench average result (excluding passage_count as in original)
if NA_flag:
summary['LB_average_result'] = 'N/A'
else:
average_result = 0
valid_count = 0
for dataset in LONGBENCH_DATASETS:
if dataset != 'passage_count' and dataset in ind_dataset_result:
if ind_dataset_result[dataset] != 'N/A':
average_result += ind_dataset_result[dataset]
valid_count += 1
if valid_count > 0:
summary['LB_average_result'] = np.round(average_result /
valid_count,
decimals=2)
else:
summary['LB_average_result'] = 'N/A'
# Log summary statistics
logger.info("Task Summary Results:")
logger.info(f"Overall LongBench Average: {summary['LB_average_result']}")
for task, score in task_ave_result.items():
logger.info(f"{task}: {score}")
return summary
def save_results(results: List[Dict], dataset: str, args: argparse.Namespace,
inference_time: float, output_dir: str):
"""Save evaluation results in format compatible with LongBench."""
task_output_dir = os.path.join(output_dir, dataset)
os.makedirs(task_output_dir, exist_ok=True)
# Extract predictions, answers, and all_classes for evaluation
predictions = [r['prediction'] for r in results]
answers_list = [r['answers'] for r in results]
all_classes_list = [r.get('all_classes', []) for r in results]
# Calculate metrics
processed_results, scores = calculate_metrics(dataset, predictions,
answers_list,
all_classes_list,
args.longbench_path)
logger.info(f"Evaluation metrics: {processed_results}")
# Save detailed results for manual inspection
results_file = os.path.join(task_output_dir, f"{dataset}_results.jsonl")
with open(results_file, 'w', encoding='utf-8') as f:
for result in results:
json.dump(result, f, ensure_ascii=False)
f.write('\n')
# Save prediction results in LongBench format for evaluation
pred_dir = os.path.join(task_output_dir, "pred")
os.makedirs(pred_dir, exist_ok=True)
pred_file = os.path.join(pred_dir, f"{dataset}.jsonl")
with open(pred_file, 'w', encoding='utf-8') as f:
for idx, result in enumerate(results):
pred_data = {
"pred": result['prediction'],
"answers": result['answers'],
"all_classes": result.get('all_classes', []),
"length": result.get('length', 0),
"score": scores[idx]
}
json.dump(pred_data, f, ensure_ascii=False)
f.write('\n')
# Create summary following LongBench format
config = {
'pipeline_params': {
'model_name': args.model_path,
'method': args.attention_backend,
'token_budget': args.token_budget,
'max_seq_len': args.max_seq_len,
'max_new_tokens': args.max_new_tokens,
'window_size': args.window_size,
'kernel_size': args.kernel_size,
'num_processes': 1, # Single process
'devices': "0" # Single device
},
'eval_params': {
'dataset': dataset,
'num_samples': len(results)
},
'eval_results': {
'processed_results': processed_results
},
'management': {
'output_folder_dir': task_output_dir,
'exp_desc':
f'{dataset}_{os.path.basename(args.model_path)}_{args.attention_backend}_{args.token_budget}',
'total_inference_time': inference_time,
'avg_inference_time': inference_time / len(results),
'evaluation_timestamp': datetime.now().isoformat()
}
}
# Save summary
summary_file = os.path.join(task_output_dir, f"{dataset}_summary.json")
with open(summary_file, 'w', encoding='utf-8') as f:
json.dump(config, f, indent=2, ensure_ascii=False)
logger.info(f"The results of {dataset} are saved to {task_output_dir}")
return processed_results
def main():
"""Main evaluation function."""
args = parse_arguments()
logger.set_level(args.log_level)
# Setup experiment name
if not args.exp_name:
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
model_name = os.path.basename(args.model_path).replace('/', '_')
args.exp_name = f"longbench_{model_name}_{timestamp}"
output_dir = os.path.join(args.output_dir, args.exp_name)
logger.info(
"=========== LongBench Evaluation with TensorRT-LLM ===========")
os.makedirs(output_dir, exist_ok=True)
# Save configuration
config_file = os.path.join(output_dir, "config.json")
with open(config_file, 'w') as f:
json.dump(vars(args), f, indent=2)
logger.info(f"Configuration saved to {config_file}")
# Determine datasets to evaluate
if args.run_all_tasks:
datasets = LONGBENCH_DATASETS
logger.info(f"Running evaluation on full LongBench datasets")
else:
datasets = args.dataset
logger.info(f"Running evaluation on datasets: {args.dataset}")
# Initialize LLM and tokenizer
llm, tokenizer = initialize_llm(args)
# Process datasets sequentially
all_results = {}
for dataset_idx, dataset in enumerate(datasets):
logger.info(f"{'='*30}")
logger.info(
f"Processing dataset {dataset_idx+1}/{len(datasets)}: {dataset}...")
# Evaluate the dataset
results, inference_time = evaluate_single_dataset(
dataset, llm, tokenizer, args)
# Save results and get metrics
processed_results = save_results(results, dataset, args, inference_time,
output_dir)
all_results[dataset] = {
'num_samples': len(results),
'inference_time': inference_time,
'output_dir': output_dir,
'metrics': processed_results
}
logger.info(f"Dataset {dataset} completed successfully")
total_time = sum(all_results[d]['inference_time'] for d in all_results)
# Calculate task-level summary
task_summary = calculate_task_summary(all_results)
# Save overall summary with task statistics
overall_summary = {
'experiment_name':
args.exp_name,
'total_evaluation_time':
total_time,
'evaluated_datasets':
list(all_results.keys()),
'successful_datasets':
[d for d, r in all_results.items() if 'error' not in r],
'failed_datasets': [d for d, r in all_results.items() if 'error' in r],
'results_by_dataset':
all_results,
'task_summary':
task_summary, # Add task-level summary
'configuration':
vars(args)
}
overall_summary_file = os.path.join(output_dir, "overall_summary.json")
with open(overall_summary_file, 'w') as f:
json.dump(overall_summary, f, indent=2)
logger.info(f"\n{'-'*80}")
logger.info(
f"Evaluation completed. Overall summary saved to: {overall_summary_file}"
)
logger.info(f"Total time: {total_time:.2f} seconds")
# Print final summary
if task_summary['LB_average_result'] != 'N/A':
logger.info(f"FINAL RESULTS:")
logger.info(
f"LongBench Overall Average: {task_summary['LB_average_result']}")
logger.info(f"Task-level results:")
for task, score in task_summary['task_average_result'].items():
logger.info(f" {task}: {score}")
if overall_summary['failed_datasets']:
logger.warning(f"Failed datasets: {overall_summary['failed_datasets']}")
if __name__ == '__main__':
main()

View File

@ -0,0 +1,771 @@
#!/usr/bin/env python3
"""
LongBench v2 evaluation script with TensorRT-LLM and sparse attention.
Usage:
python eval_longbench_v2.py --model_path /path/to/model --longbench_path ./LongBench --output_dir results/
# Run all LongBench v2 samples
python eval_longbench_v2.py --model_path /path/to/model --longbench_path ./LongBench --output_dir results/
# Enable CoT reasoning
python eval_longbench_v2.py --model_path /path/to/model --longbench_path ./LongBench --output_dir results/ --cot
# Run with different difficulty levels
python eval_longbench_v2.py --model_path /path/to/model --longbench_path ./LongBench --output_dir results/ --difficulty easy
"""
import argparse
import json
import os
import re
import time
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple
from datasets import load_dataset
from transformers import AutoTokenizer
# Add tensorrt_llm imports
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import KvCacheConfig, RocketSparseAttentionConfig
from tensorrt_llm.logger import logger
# Chat templates mapping
CHAT_TEMPLATES = {
"llama3.1-8b-instruct": "llama3",
"llama3-8b-instruct": "llama3",
"mistral-7b-instruct-v0.2": "mistral",
"longchat-7b-v1.5-32k": "vicuna"
}
def parse_arguments() -> argparse.Namespace:
"""Parse command line arguments."""
parser = argparse.ArgumentParser(
description="LongBench v2 evaluation with TensorRT-LLM and RocketKV")
# Model and data arguments
parser.add_argument('--model_path',
type=str,
required=True,
help='Path to model (HF model name or local path)')
parser.add_argument('--longbench_path',
type=str,
default='./LongBench',
help='Path to LongBench directory')
# Output arguments
parser.add_argument('--output_dir',
type=str,
required=True,
help='Directory to save results')
parser.add_argument('--exp_name',
type=str,
default=None,
help='Experiment name (auto-generated if not provided)')
# Model configuration
parser.add_argument('--attention_backend',
type=str,
default='VANILLA',
choices=['VANILLA', 'TRTLLM', 'FLASHINFER'],
help='Attention backend to use')
parser.add_argument('--backend',
type=str,
default='pytorch',
choices=['pytorch', 'tensorrt'],
help='LLM backend to use')
parser.add_argument('--chat_template',
type=str,
default='auto',
help='Chat template to use (auto-detect if "auto")')
# Sequence and batch configuration
parser.add_argument('--max_seq_len',
type=int,
default=133120,
help='Maximum sequence length')
parser.add_argument('--max_batch_size',
type=int,
default=1,
help='Maximum batch size')
parser.add_argument('--max_new_tokens',
type=int,
default=256,
help='Maximum new tokens to generate')
parser.add_argument(
'--max_num_tokens',
type=int,
default=133120,
help='Maximum total tokens across all sequences in a batch')
parser.add_argument('--tensor_parallel_size',
type=int,
default=1,
help='Tensor parallel size')
# RocketKV configuration
parser.add_argument('--rocket_sparse',
action='store_true',
help='Use rocket sparse attention')
parser.add_argument('--token_budget',
type=int,
default=2048,
help='Token budget for RocketKV (prompt_budget)')
parser.add_argument('--window_size',
type=int,
default=32,
help='Window size for RocketKV')
parser.add_argument('--kernel_size',
type=int,
default=63,
help='Kernel size for RocketKV')
parser.add_argument('--topr',
type=int,
default=90,
help='Top-r for RocketKV')
# KV cache configuration
parser.add_argument('--kv_cache_dtype',
type=str,
default='auto',
help='KV cache data type')
parser.add_argument('--kv_cache_fraction',
type=float,
default=0.7,
help='Fraction of GPU memory for KV cache')
# LongBench v2 specific arguments
parser.add_argument('--cot',
action='store_true',
help='Enable Chain-of-Thought reasoning')
parser.add_argument('--no_context',
action='store_true',
help='Test without long context (pure memorization)')
parser.add_argument('--rag',
type=int,
default=0,
help='Use top-N retrieved contexts (0 to disable)')
# Evaluation parameters
parser.add_argument('--num_samples',
type=int,
default=None,
help='Number of samples to evaluate (None for all)')
parser.add_argument('--start_idx',
type=int,
default=0,
help='Start index for evaluation')
parser.add_argument('--difficulty',
type=str,
choices=['easy', 'hard'],
default=None,
help='Filter by difficulty level')
parser.add_argument('--length',
type=str,
choices=['short', 'medium', 'long'],
default=None,
help='Filter by length category')
parser.add_argument('--domain',
type=str,
default=None,
help='Filter by domain')
parser.add_argument('--max_len',
type=int,
default=120000,
help='Maximum prompt length in tokens for truncation')
# System arguments
parser.add_argument('--log_level',
type=str,
default='info',
choices=['debug', 'info', 'warning', 'error'],
help='Logging level')
parser.add_argument('--seed', type=int, default=42, help='Random seed')
return parser.parse_args()
def load_longbench_v2_config(longbench_path: str) -> Dict[str, Any]:
"""Load LongBench v2 configuration files."""
config_dir = os.path.join(longbench_path, "config")
# Load model2maxlen.json for v2
maxlen_file = os.path.join(config_dir, "model2maxlen.json")
with open(maxlen_file, 'r', encoding='utf-8') as f:
model2maxlen = json.load(f)
# Load prompt templates
prompts_dir = os.path.join(longbench_path, "prompts")
templates = {}
template_files = {
'0shot': '0shot.txt',
'0shot_cot': '0shot_cot.txt',
'0shot_cot_ans': '0shot_cot_ans.txt',
'0shot_no_context': '0shot_no_context.txt',
'0shot_rag': '0shot_rag.txt'
}
for template_name, filename in template_files.items():
template_path = os.path.join(prompts_dir, filename)
if os.path.exists(template_path):
with open(template_path, 'r', encoding='utf-8') as f:
templates[template_name] = f.read()
return {'model2maxlen': model2maxlen, 'templates': templates}
def build_chat(tokenizer, prompt, chat_template):
"""Build chat prompt following LongBench's approach."""
if chat_template == "vicuna" or chat_template == "longchat":
try:
from fastchat.model import get_conversation_template
conv = get_conversation_template("vicuna")
conv.append_message(conv.roles[0], prompt)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
except ImportError:
# Fallback if fastchat is not available
prompt = f"A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n\nUSER: {prompt}\nASSISTANT:"
elif chat_template == "llama2":
prompt = f"[INST]{prompt}[/INST]"
elif chat_template == "llama3":
prompt = f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
elif chat_template == "mistral":
prompt = f"[INST] {prompt} [/INST]"
# For other templates or "none", return prompt as-is
return prompt
def determine_chat_template(model_path: str, chat_template: str) -> str:
"""Determine chat template based on model path."""
if chat_template != 'auto':
return chat_template
model_path_lower = model_path.lower()
for model_key, template in CHAT_TEMPLATES.items():
if model_key.replace('-', '').replace('.',
'') in model_path_lower.replace(
'-', '').replace('.', ''):
return template
# Default fallback
if 'llama' in model_path_lower:
return 'llama3'
elif 'mistral' in model_path_lower:
return 'mistral'
else:
return 'none' # No special formatting
def extract_answer(response: str) -> Optional[str]:
"""Extract answer from response following LongBench v2's approach."""
response = response.replace('*', '')
# Try to extract answer in format "The correct answer is (X)"
match = re.search(r'The correct answer is \(([A-D])\)', response)
if match:
return match.group(1)
# Try to extract answer in format "The correct answer is X"
match = re.search(r'The correct answer is ([A-D])', response)
if match:
return match.group(1)
# Try to extract any single letter A-D
match = re.search(r'\b([A-D])\b', response)
if match:
return match.group(1)
return None
def post_process(pred: str, chat_template: str) -> str:
"""Post-process prediction following LongBench v2's approach."""
pred = pred.split("</s")[0].strip()
if chat_template == "qwen":
pred = pred.split("<|im_end|>")[0]
elif "llama2" in chat_template.lower():
pred = (pred.split("(Document")[0].split("\n\nQuestion")[0].split(
"\n\nAnswer")[0].split("[INST]")[0].split("[/INST]")[0].split(
"(Passage")[0].strip())
return pred
def truncate_prompt(prompt: str, tokenizer: AutoTokenizer, max_len: int) -> str:
"""Truncate prompt following LongBench v2's approach."""
# Encode the prompt using the tokenizer
input_ids = tokenizer.encode(prompt, add_special_tokens=False)
# If prompt exceeds max_len, truncate by taking first half and last half
if len(input_ids) > max_len:
half = max_len // 2
truncated_ids = input_ids[:half] + input_ids[-half:]
# Decode back to text
prompt = tokenizer.decode(truncated_ids, skip_special_tokens=True)
return prompt
def format_prompt(sample: Dict[str, Any], template: str,
args: argparse.Namespace) -> str:
"""Format prompt for LongBench v2."""
context = sample['context']
# Handle RAG mode
if args.rag > 0 and 'retrieved_context' in sample:
retrieved = sample["retrieved_context"][:args.rag]
retrieved = sorted(retrieved, key=lambda x: x.get('c_idx', 0))
context = '\n\n'.join([
f"Retrieved chunk {idx+1}: {x['content']}"
for idx, x in enumerate(retrieved)
])
# Handle no context mode
if args.no_context:
context = ""
# Format the prompt using the template
prompt = template.replace('$DOC$', context.strip())
prompt = prompt.replace('$Q$', sample['question'].strip())
prompt = prompt.replace('$C_A$', sample['choice_A'].strip())
prompt = prompt.replace('$C_B$', sample['choice_B'].strip())
prompt = prompt.replace('$C_C$', sample['choice_C'].strip())
prompt = prompt.replace('$C_D$', sample['choice_D'].strip())
return prompt
def initialize_llm(args: argparse.Namespace) -> Tuple[LLM, AutoTokenizer]:
"""Initialize LLM and tokenizer."""
logger.info(f"Initializing LLM with model: {args.model_path}")
try:
# Configure KV cache
kv_cache_config = KvCacheConfig(
enable_block_reuse=False, # RocketKV doesn't support KV cache reuse
)
if args.rocket_sparse:
# Configure RocketKV sparse attention
sparse_attention_config = RocketSparseAttentionConfig(
window_size=args.window_size,
kernel_size=args.kernel_size,
prompt_budget=args.token_budget,
topr=args.topr,
)
logger.info(f"Using RocketKV sparse attention")
else:
sparse_attention_config = None
logger.info("Using standard attention")
# Initialize LLM
llm = LLM(
model=args.model_path,
backend=args.backend,
kv_cache_config=kv_cache_config,
attn_backend=args.attention_backend,
sparse_attention_config=sparse_attention_config,
tensor_parallel_size=args.tensor_parallel_size,
max_seq_len=args.max_seq_len,
max_num_tokens=args.max_num_tokens,
cuda_graph_config=None,
torch_compile_config=None,
)
# Initialize tokenizer
tokenizer = AutoTokenizer.from_pretrained(args.model_path)
logger.info("LLM and tokenizer initialized successfully")
return llm, tokenizer
except Exception as e:
logger.error(f"Failed to initialize LLM: {e}")
raise e
def evaluate_longbench_v2(llm: LLM, tokenizer: AutoTokenizer,
args: argparse.Namespace) -> Tuple[List[Dict], float]:
"""Evaluate on LongBench v2 dataset."""
# Load LongBench v2 configuration
config = load_longbench_v2_config(args.longbench_path)
# Determine max_len for the model if not explicitly set via arguments
model_name = os.path.basename(args.model_path)
if model_name in config['model2maxlen']:  # Use default from config if available
max_len = config['model2maxlen'][model_name]
logger.info(f"Using model-specific max_len: {max_len} for {model_name}")
else:
max_len = args.max_len
logger.info(f"Using default max_len: {max_len}")
# Update args with the determined max_len
args.max_len = max_len
# Load dataset
logger.info(f"Loading LongBench v2 dataset...")
dataset = load_dataset('THUDM/LongBench-v2',
split='train',
trust_remote_code=True)
data = [item for item in dataset]
# Apply filters
filtered_data = data
if args.difficulty:
filtered_data = [
item for item in filtered_data
if item['difficulty'] == args.difficulty
]
logger.info(
f"Filtered by difficulty '{args.difficulty}': {len(filtered_data)} samples"
)
if args.length:
filtered_data = [
item for item in filtered_data if item['length'] == args.length
]
logger.info(
f"Filtered by length '{args.length}': {len(filtered_data)} samples")
if args.domain:
filtered_data = [
item for item in filtered_data if item['domain'] == args.domain
]
logger.info(
f"Filtered by domain '{args.domain}': {len(filtered_data)} samples")
# Apply start_idx and num_samples
if args.num_samples:
end_idx = min(args.start_idx + args.num_samples, len(filtered_data))
filtered_data = filtered_data[args.start_idx:end_idx]
else:
filtered_data = filtered_data[args.start_idx:]
logger.info(f"Final dataset size: {len(filtered_data)} samples")
# Determine chat template
chat_template = determine_chat_template(args.model_path, args.chat_template)
logger.info(f"Using chat template: {chat_template}")
logger.info(f"Prepare and truncate prompts...")
# Select appropriate template
if args.no_context:
template_key = '0shot_no_context'
elif args.rag > 0:
template_key = '0shot_rag'
elif args.cot:
template_key = '0shot_cot'
else:
template_key = '0shot'
template = config['templates'][template_key]
logger.info(f"Using template: {template_key}")
# Set up extra end token ids
extra_end_token_ids = []
if chat_template == "llama3":
eot_id = tokenizer.encode("<|eot_id|>", add_special_tokens=False)[0]
extra_end_token_ids.append(eot_id)
logger.info(f"Added llama3 end token: {eot_id}")
if chat_template == "qwen":
im_end_id = tokenizer.encode("<|im_end|>", add_special_tokens=False)[0]
extra_end_token_ids.append(im_end_id)
logger.info(f"Added qwen end token: {im_end_id}")
# Prepare prompts
prompts = []
for sample in filtered_data:
formatted_prompt = format_prompt(sample, template, args)
# Apply chat formatting if needed
if chat_template != 'none':
formatted_prompt = build_chat(tokenizer, formatted_prompt,
chat_template)
# Apply truncation if prompt is too long
formatted_prompt = truncate_prompt(formatted_prompt, tokenizer,
args.max_len)
prompts.append(formatted_prompt)
if len(prompts) == 0:
logger.warning(f"No prompts to evaluate")
return [], 0.0
# Run inference
logger.info(f"Starting inference for {len(prompts)} samples...")
start_time = time.time()
# Set sampling parameters
max_new_tokens = 1024 if args.cot else 256
sampling_params = SamplingParams(
max_tokens=max_new_tokens,
temperature=0.1,
top_p=0.95,
stop_token_ids=extra_end_token_ids if extra_end_token_ids else None,
)
outputs = llm.generate(prompts, sampling_params)
inference_time = time.time() - start_time
logger.info(
f"Inference completed in {inference_time:.2f} seconds, average time per sample: {inference_time/len(prompts):.3f} seconds"
)
# Process outputs
results = []
for i, (sample, output) in enumerate(zip(filtered_data, outputs)):
prediction = output.outputs[0].text.strip()
processed_prediction = post_process(prediction, chat_template)
# Handle CoT mode
if args.cot:
# For CoT, we need to do a second inference to extract the final answer
cot_response = processed_prediction
# Format the CoT answer extraction prompt
cot_ans_template = config['templates']['0shot_cot_ans']
cot_ans_prompt = format_prompt(sample, cot_ans_template, args)
cot_ans_prompt = cot_ans_prompt.replace('$COT$', cot_response)
if chat_template != 'none':
cot_ans_prompt = build_chat(tokenizer, cot_ans_prompt,
chat_template)
# Apply truncation to CoT answer extraction prompt
cot_ans_prompt = truncate_prompt(cot_ans_prompt, tokenizer,
args.max_len)
# Generate final answer
ans_sampling_params = SamplingParams(
max_tokens=128,
temperature=0.1,
top_p=0.95,
stop_token_ids=extra_end_token_ids
if extra_end_token_ids else None,
)
ans_output = llm.generate([cot_ans_prompt], ans_sampling_params)[0]
final_prediction = post_process(ans_output.outputs[0].text.strip(),
chat_template)
extracted_answer = extract_answer(final_prediction)
else:
extracted_answer = extract_answer(processed_prediction)
# Calculate accuracy
is_correct = extracted_answer == sample[
'answer'] if extracted_answer else False
result = {
'_id': sample['_id'],
'domain': sample['domain'],
'sub_domain': sample['sub_domain'],
'difficulty': sample['difficulty'],
'length': sample['length'],
'question': sample['question'],
'choice_A': sample['choice_A'],
'choice_B': sample['choice_B'],
'choice_C': sample['choice_C'],
'choice_D': sample['choice_D'],
'answer': sample['answer'],
'prediction': processed_prediction,
'extracted_answer': extracted_answer,
'is_correct': is_correct,
'context_length': len(sample['context']),
'prompt_length': len(output.prompt_token_ids),
'output_length': len(output.outputs[0].token_ids),
}
if args.cot:
result['cot_response'] = cot_response
result['final_prediction'] = final_prediction
results.append(result)
return results, inference_time
def calculate_metrics(results: List[Dict]) -> Dict[str, Any]:
"""Calculate evaluation metrics for LongBench v2."""
if not results:
return {}
total_samples = len(results)
correct_samples = sum(1 for r in results if r['is_correct'])
overall_accuracy = correct_samples / total_samples
metrics = {
'overall_accuracy': round(overall_accuracy * 100, 2),
'total_samples': total_samples,
'correct_samples': correct_samples
}
# Calculate metrics by difficulty
difficulties = ['easy', 'hard']
for difficulty in difficulties:
diff_results = [r for r in results if r['difficulty'] == difficulty]
if diff_results:
diff_correct = sum(1 for r in diff_results if r['is_correct'])
metrics[f'{difficulty}_accuracy'] = round(
(diff_correct / len(diff_results)) * 100, 2)
metrics[f'{difficulty}_samples'] = len(diff_results)
# Calculate metrics by length
lengths = ['short', 'medium', 'long']
for length in lengths:
len_results = [r for r in results if r['length'] == length]
if len_results:
len_correct = sum(1 for r in len_results if r['is_correct'])
metrics[f'{length}_accuracy'] = round(
(len_correct / len(len_results)) * 100, 2)
metrics[f'{length}_samples'] = len(len_results)
# Calculate metrics by domain
domains = list(set(r['domain'] for r in results))
for domain in domains:
domain_results = [r for r in results if r['domain'] == domain]
if domain_results:
domain_correct = sum(1 for r in domain_results if r['is_correct'])
metrics[f'{domain}_accuracy'] = round(
(domain_correct / len(domain_results)) * 100, 2)
return metrics
def save_results(results: List[Dict], args: argparse.Namespace,
inference_time: float, output_dir: str):
"""Save evaluation results in format compatible with LongBench v2."""
os.makedirs(output_dir, exist_ok=True)
# Calculate metrics
metrics = calculate_metrics(results)
logger.info(f"Evaluation metrics: {metrics}")
# Save detailed results
results_file = os.path.join(output_dir, "longbench_v2_results.jsonl")
with open(results_file, 'w', encoding='utf-8') as f:
for result in results:
json.dump(result, f, ensure_ascii=False)
f.write('\n')
# Save prediction results in LongBench v2 format
pred_file = os.path.join(output_dir, "predictions.jsonl")
with open(pred_file, 'w', encoding='utf-8') as f:
for result in results:
pred_data = {
"_id": result['_id'],
"prediction": result['extracted_answer'],
"response": result['prediction'],
"judge": result['is_correct']
}
if args.cot:
pred_data['cot_response'] = result.get('cot_response', '')
pred_data['final_prediction'] = result.get(
'final_prediction', '')
json.dump(pred_data, f, ensure_ascii=False)
f.write('\n')
# Create summary
summary = {
'experiment_config': {
'model_path': args.model_path,
'attention_backend': args.attention_backend,
'rocket_sparse': args.rocket_sparse,
'token_budget': args.token_budget,
'cot': args.cot,
'no_context': args.no_context,
'rag': args.rag,
'difficulty_filter': args.difficulty,
'length_filter': args.length,
'domain_filter': args.domain,
'max_seq_len': args.max_seq_len,
'max_new_tokens': args.max_new_tokens
},
'evaluation_results': metrics,
'timing': {
'total_inference_time': inference_time,
'avg_inference_time':
inference_time / len(results) if results else 0,
'evaluation_timestamp': datetime.now().isoformat()
}
}
# Save summary
summary_file = os.path.join(output_dir, "summary.json")
with open(summary_file, 'w', encoding='utf-8') as f:
json.dump(summary, f, indent=2, ensure_ascii=False)
logger.info(f"Results saved to {output_dir}")
return metrics
def main():
"""Main evaluation function."""
args = parse_arguments()
logger.set_level(args.log_level)
# Setup experiment name
if not args.exp_name:
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
model_name = os.path.basename(args.model_path).replace('/', '_')
args.exp_name = f"longbench_v2_{model_name}_{timestamp}"
output_dir = os.path.join(args.output_dir, args.exp_name)
logger.info(
"=========== LongBench v2 Evaluation with TensorRT-LLM ===========")
os.makedirs(output_dir, exist_ok=True)
# Save configuration
config_file = os.path.join(output_dir, "config.json")
with open(config_file, 'w') as f:
json.dump(vars(args), f, indent=2)
logger.info(f"Configuration saved to {config_file}")
# Initialize LLM and tokenizer
llm, tokenizer = initialize_llm(args)
# Run evaluation
logger.info(f"Starting LongBench v2 evaluation...")
results, inference_time = evaluate_longbench_v2(llm, tokenizer, args)
# Save results and get metrics
metrics = save_results(results, args, inference_time, output_dir)
logger.info(f"{'-'*80}")
logger.info(f"Evaluation completed successfully!")
logger.info(f"Total time: {inference_time:.2f} seconds")
logger.info(f"Total samples: {len(results)}")
if metrics:
logger.info(
f"Overall accuracy: {metrics.get('overall_accuracy', 'N/A')}%")
if 'easy_accuracy' in metrics:
logger.info(
f"Easy difficulty accuracy: {metrics['easy_accuracy']}% ({metrics.get('easy_samples', 0)} samples)"
)
if 'hard_accuracy' in metrics:
logger.info(
f"Hard difficulty accuracy: {metrics['hard_accuracy']}% ({metrics.get('hard_samples', 0)} samples)"
)
for length in ['short', 'medium', 'long']:
if f'{length}_accuracy' in metrics:
logger.info(
f"{length.capitalize()} length accuracy: {metrics[f'{length}_accuracy']}% ({metrics.get(f'{length}_samples', 0)} samples)"
)
if __name__ == '__main__':
main()

View File

@ -0,0 +1,3 @@
jieba
fuzzywuzzy
rouge

View File

@ -1,5 +1,6 @@
from ..flashinfer_utils import IS_FLASHINFER_AVAILABLE
from .interface import AttentionBackend, AttentionMetadata
from .sparse import get_sparse_attn_kv_cache_manager
from .trtllm import AttentionInputType, TrtllmAttention, TrtllmAttentionMetadata
from .vanilla import VanillaAttention, VanillaAttentionMetadata
@ -11,6 +12,7 @@ __all__ = [
"TrtllmAttentionMetadata",
"VanillaAttention",
"VanillaAttentionMetadata",
"get_sparse_attn_kv_cache_manager",
]
if IS_FLASHINFER_AVAILABLE:

View File

@ -140,6 +140,7 @@ class AttentionMetadata:
_saved_tensors: Dict[str, torch.Tensor] = field(init=False,
default_factory=dict)
sparse_attention_config: Optional["SparseAttentionConfig"] = None
def __post_init__(self) -> None:
if self.is_cross:
@ -563,6 +564,7 @@ class AttentionBackend(Generic[TMetadata]):
num_kv_heads: Optional[int] = None,
quant_config: Optional[QuantConfig] = None,
skip_create_weights_in_init: bool = False,
sparse_attention_config: Optional["SparseAttentionConfig"] = None,
**kwargs,
):
"""
@ -573,12 +575,14 @@ class AttentionBackend(Generic[TMetadata]):
head_dim (int): The size of each attention head (hidden_size // num_heads).
num_kv_heads (int): The number of kv heads. Defaults to num_heads if None.
quant_config (QuantConfig): Optional quantization configuration. If None, no quantization is applied.
sparse_attention_config (SparseAttentionConfig): Optional sparse attention configuration. If None, no sparse attention is applied.
"""
self.layer_idx = layer_idx
self.num_heads = num_heads
self.head_dim = head_dim
self.num_kv_heads = num_kv_heads or self.num_heads
self.quant_config = quant_config
self.sparse_attention_config = sparse_attention_config
def update_quant_config(self, new_quant_config: Optional[QuantConfig]):
"""

View File

@ -0,0 +1,11 @@
from .utils import (get_flashinfer_sparse_attn_attention_backend,
get_sparse_attn_kv_cache_manager,
get_trtllm_sparse_attn_attention_backend,
get_vanilla_sparse_attn_attention_backend)
__all__ = [
"get_sparse_attn_kv_cache_manager",
"get_vanilla_sparse_attn_attention_backend",
"get_trtllm_sparse_attn_attention_backend",
"get_flashinfer_sparse_attn_attention_backend",
]

View File

@ -0,0 +1,308 @@
import torch
import triton
import triton.language as tl
@triton.jit
def _index_gather_kernel(output_ptr, input_ptr, index_ptr, in_row_stride,
in_token_stride, in_head_stride, idx_row_stride,
idx_token_stride, idx_head_stride, dim_size,
BLOCK_SIZE: tl.constexpr):
# get program id and block offset
row_pid = tl.program_id(0)
token_pid = tl.program_id(1)
head_pid = tl.program_id(2)
token_block_num = tl.num_programs(1)
head_num = tl.num_programs(2)
# get index
indices_idx = row_pid * idx_row_stride + token_pid * idx_token_stride + head_pid * idx_head_stride
token_idx = tl.load(index_ptr + indices_idx)
# get input and output base address
input_base = (row_pid * in_row_stride + token_idx * in_token_stride +
head_pid * in_head_stride)
output_base = (row_pid * token_block_num * head_num * dim_size +
token_pid * head_num * dim_size + head_pid * dim_size)
# process elements in blocks
for dim_offset in tl.range(0, dim_size, BLOCK_SIZE):
# get offsets
offsets = tl.arange(0, BLOCK_SIZE)
dim_indices = dim_offset + offsets
mask = dim_indices < dim_size
# load input and store output
input_val = tl.load(input_ptr + input_base + dim_indices,
mask=mask,
other=0.0)
tl.store(output_ptr + output_base + dim_indices, input_val, mask=mask)
def triton_index_gather(input, indices):
assert input.ndim == 4, "Input must be a 4D tensor, [row, token, head, dim]"
assert indices.ndim == 3, "Indices must be a 3D tensor, [row, token, head]"
# shape of input and indices
row_size = input.shape[0]
head_num = input.shape[2]
dim_size = input.shape[3]
num_tokens = indices.shape[1]
# create output tensor
output = torch.empty((row_size, num_tokens, head_num, dim_size),
device='cuda',
dtype=input.dtype)
# launch kernel
grid = (row_size, num_tokens, head_num)
_index_gather_kernel[grid](output,
input,
indices,
input.stride(0),
input.stride(1),
input.stride(2),
indices.stride(0),
indices.stride(1),
indices.stride(2),
dim_size,
BLOCK_SIZE=1024)
return output
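# Usage sketch (illustrative, not part of this diff; shapes are assumptions that satisfy the
# asserts above). For each (row, head) pair the kernel gathers whole head vectors along the
# token axis, i.e. out[r, t, h, :] == input[r, indices[r, t, h], h, :], so torch.gather can
# serve as a reference:
#   x = torch.randn(2, 16, 4, 64, device='cuda', dtype=torch.float16)
#   idx = torch.randint(0, 16, (2, 8, 4), device='cuda')
#   out = triton_index_gather(x, idx)  # shape (2, 8, 4, 64)
#   ref = torch.gather(x, 1, idx.unsqueeze(-1).expand(-1, -1, -1, 64))
#   assert torch.equal(out, ref)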
@triton.jit
def _update_kt_cache_ctx_kernel(k_ptr, cache_ptr, block_offsets_ptr,
cum_seq_lens_ptr, cum_kt_seq_lens_ptr,
token_to_batch_map_ptr, num_kv_heads, dim_size,
kt_page_size, tokens_per_block,
max_kt_blocks_per_seq,
BLOCK_SIZE: tl.constexpr):
# get program id
kt_token_idx = tl.program_id(0)
# get params
batch_idx = tl.load(token_to_batch_map_ptr + kt_token_idx)
kv_start_idx = tl.load(cum_seq_lens_ptr + batch_idx)
kv_end_idx = tl.load(cum_seq_lens_ptr + batch_idx + 1)
kt_start_idx = tl.load(cum_kt_seq_lens_ptr + batch_idx)
local_kt_token_idx = kt_token_idx - kt_start_idx
global_kv_token_idx = kv_start_idx + local_kt_token_idx * kt_page_size
# get offsets
hidden_size = num_kv_heads * dim_size
k_base = k_ptr + global_kv_token_idx * hidden_size
block_offset = batch_idx * max_kt_blocks_per_seq + local_kt_token_idx // tokens_per_block
block_idx = tl.load(block_offsets_ptr + block_offset)
token_idx_in_block = local_kt_token_idx % tokens_per_block
cache_base = cache_ptr + (block_idx * tokens_per_block +
token_idx_in_block) * hidden_size * 2
# compute min/max and store kt
for hidden_start in tl.range(0, hidden_size, BLOCK_SIZE):
hidden_indices = hidden_start + tl.arange(0, BLOCK_SIZE)
head_idx = hidden_indices // dim_size
dim_idx = hidden_indices % dim_size
dim_mask = hidden_indices < hidden_size
# get k_min and k_max
k_min = tl.full((BLOCK_SIZE, ), float('inf'), dtype=tl.float32)
k_max = tl.full((BLOCK_SIZE, ), float('-inf'), dtype=tl.float32)
for i in range(kt_page_size):
if global_kv_token_idx + i < kv_end_idx:
k = tl.load(k_base + i * hidden_size + hidden_indices,
mask=dim_mask,
other=0.0)
k_min = tl.minimum(k_min, k)
k_max = tl.maximum(k_max, k)
k_min = k_min.to(cache_ptr.dtype.element_ty)
k_max = k_max.to(cache_ptr.dtype.element_ty)
# store k_min and k_max to cache
k_min_offset = cache_base + head_idx * dim_size * 2 + dim_idx
k_max_offset = k_min_offset + dim_size
tl.store(k_min_offset, k_min, mask=dim_mask)
tl.store(k_max_offset, k_max, mask=dim_mask)
@triton.jit
def _update_kt_cache_gen_kernel(k_ptr, cache_ptr, block_offsets_ptr,
seq_lens_ptr, num_kv_heads, dim_size,
kt_page_size, tokens_per_block,
max_kt_blocks_per_seq,
BLOCK_SIZE: tl.constexpr):
# get program id
batch_idx = tl.program_id(0)
head_idx = tl.program_id(1)
# get params
past_key_value_length = tl.load(seq_lens_ptr + batch_idx) - 1
kt_token_idx = past_key_value_length // kt_page_size
kt_token_idx_in_page = past_key_value_length % kt_page_size
# get offsets
hidden_size = num_kv_heads * dim_size
k_base = k_ptr + batch_idx * hidden_size + head_idx * dim_size
block_offset = batch_idx * max_kt_blocks_per_seq + kt_token_idx // tokens_per_block
block_idx = tl.load(block_offsets_ptr + block_offset)
kt_token_idx_in_block = kt_token_idx % tokens_per_block
cache_base = cache_ptr + (block_idx * tokens_per_block +
kt_token_idx_in_block) * hidden_size * 2
cache_base += head_idx * dim_size * 2
# update kt cache
for hidden_start in tl.range(0, dim_size, BLOCK_SIZE):
hidden_indices = hidden_start + tl.arange(0, BLOCK_SIZE)
dim_mask = hidden_indices < dim_size
# load k
k = tl.load(k_base + hidden_indices, mask=dim_mask, other=0.0)
# load kt cache
kt_mask = dim_mask & (kt_token_idx_in_page > 0)
k_min = tl.load(cache_base + hidden_indices,
mask=kt_mask,
other=float('inf'))
k_max = tl.load(cache_base + hidden_indices + dim_size,
mask=kt_mask,
other=float('-inf'))
k_min = tl.minimum(k_min, k)
k_max = tl.maximum(k_max, k)
k_min = k_min.to(cache_ptr.dtype.element_ty)
k_max = k_max.to(cache_ptr.dtype.element_ty)
# store k_min and k_max to cache
tl.store(cache_base + hidden_indices, k_min, mask=dim_mask)
tl.store(cache_base + hidden_indices + dim_size, k_max, mask=dim_mask)
@triton.jit
def _load_kt_cache_kernel(kt_states_ptr, cache_ptr, block_offsets_ptr,
cum_kt_seq_lens_ptr, token_to_batch_map_ptr,
num_kv_heads, dim_size, tokens_per_block,
max_kt_blocks_per_seq, BLOCK_SIZE: tl.constexpr):
# get program id
kt_token_idx = tl.program_id(0)
# get params
batch_idx = tl.load(token_to_batch_map_ptr + kt_token_idx)
kt_start_idx = tl.load(cum_kt_seq_lens_ptr + batch_idx)
local_kt_token_idx = kt_token_idx - kt_start_idx
# get offsets
hidden_size = num_kv_heads * dim_size * 2
kt_states_base = kt_states_ptr + kt_token_idx * hidden_size
block_offset = batch_idx * max_kt_blocks_per_seq + local_kt_token_idx // tokens_per_block
block_idx = tl.load(block_offsets_ptr + block_offset)
token_idx_in_block = local_kt_token_idx % tokens_per_block
cache_base = cache_ptr + (block_idx * tokens_per_block +
token_idx_in_block) * hidden_size
# load kt cache
for hidden_start in tl.range(0, hidden_size, BLOCK_SIZE):
hidden_indices = hidden_start + tl.arange(0, BLOCK_SIZE)
mask = hidden_indices < hidden_size
# load kt cache
kt = tl.load(cache_base + hidden_indices, mask=mask, other=0.0)
# store kt to kt_states
tl.store(kt_states_base + hidden_indices, kt, mask=mask)
def triton_update_kt_cache(k,
kt_cache_tensor,
kt_cache_block_offsets,
seq_lens,
kt_page_size,
tokens_per_block,
max_kt_blocks_per_seq,
update=True):
# inputs:
# k: (total_seq_len, num_kv_heads, head_dim)
# kt_cache_tensor: (num_blocks, tokens_per_block, num_kv_heads, 2 * head_dim)
# kt_cache_block_offsets: (max_batch_size, max_kt_blocks_per_seq)
# seq_lens: (batch_size)
# kt_page_size: int
# update: bool
# outputs:
# kt_states: (total_kt_tokens, num_kv_heads, 2 * head_dim)
# params
batch_size = seq_lens.size(0)
num_kv_heads = k.size(1)
head_dim = k.size(2)
tokens_per_block = kt_cache_tensor.size(1)
num_kt_tokens = (seq_lens + kt_page_size - 1) // kt_page_size
# context
if not update:
total_num_kt_tokens = num_kt_tokens.sum().item()
cum_seq_lens = torch.cumsum(torch.cat([
torch.zeros(1, device='cuda', dtype=torch.long),
seq_lens.to(torch.long)
]),
dim=0)
cum_kt_seq_lens = torch.cumsum(torch.cat([
torch.zeros(1, device='cuda', dtype=torch.long),
num_kt_tokens.to(torch.long)
]),
dim=0)
token_to_batch_map = torch.repeat_interleave(
torch.arange(batch_size,
device='cuda'), repeats=num_kt_tokens).to(torch.long)
grid = (total_num_kt_tokens, )
_update_kt_cache_ctx_kernel[grid](k,
kt_cache_tensor,
kt_cache_block_offsets,
cum_seq_lens,
cum_kt_seq_lens,
token_to_batch_map,
num_kv_heads,
head_dim,
kt_page_size,
tokens_per_block,
max_kt_blocks_per_seq,
BLOCK_SIZE=1024)
return
else:
# generation
# update kt cache
grid = (batch_size, num_kv_heads)
_update_kt_cache_gen_kernel[grid](k,
kt_cache_tensor,
kt_cache_block_offsets,
seq_lens,
num_kv_heads,
head_dim,
kt_page_size,
tokens_per_block,
max_kt_blocks_per_seq,
BLOCK_SIZE=1024)
# load kt cache
total_num_kt_tokens = num_kt_tokens.sum().item()
kt_states = torch.empty(
(total_num_kt_tokens, num_kv_heads, 2 * head_dim),
device='cuda',
dtype=k.dtype)
token_to_batch_map = torch.repeat_interleave(
torch.arange(batch_size,
device='cuda'), repeats=num_kt_tokens).to(torch.long)
cum_kt_seq_lens = torch.cumsum(torch.cat([
torch.zeros(1, device='cuda', dtype=torch.long),
num_kt_tokens.to(torch.long)
]),
dim=0)
grid = (total_num_kt_tokens, )
_load_kt_cache_kernel[grid](kt_states,
kt_cache_tensor,
kt_cache_block_offsets,
cum_kt_seq_lens,
token_to_batch_map,
num_kv_heads,
head_dim,
tokens_per_block,
max_kt_blocks_per_seq,
BLOCK_SIZE=1024)
return kt_states
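# Reading aid (a sketch of the layout implied by the kernels above, not a normative spec):
#   - one "kt token" summarizes kt_page_size consecutive KV tokens of a sequence;
#   - kt_cache_tensor stores, per kt token and per kv head, the elementwise running minimum
#     and maximum of the keys in that page, concatenated along the last dim as
#     [k_min (head_dim) | k_max (head_dim)];
#   - update=False (context phase): the whole prompt is summarized page by page, nothing is
#     returned;
#   - update=True (generation phase): the newest key folds into the min/max of its page, then
#     every sequence's kt pages are gathered into kt_states with shape
#     (total_kt_tokens, num_kv_heads, 2 * head_dim).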

File diff suppressed because it is too large

View File

@ -0,0 +1,39 @@
from .rocket import (RocketKVCacheManager, RocketTrtllmAttention,
RocketVanillaAttention)
def get_sparse_attn_kv_cache_manager(
sparse_attn_config: "SparseAttentionConfig"):
if sparse_attn_config.algorithm == "rocket":
return RocketKVCacheManager
else:
raise ValueError(
f"Unsupported sparse attention algorithm: {sparse_attn_config.algorithm}"
)
def get_vanilla_sparse_attn_attention_backend(
sparse_attn_config: "SparseAttentionConfig"):
if sparse_attn_config.algorithm == "rocket":
return RocketVanillaAttention
else:
raise ValueError(
f"Unsupported sparse attention algorithm in vanilla attention backend: {sparse_attn_config.algorithm}"
)
def get_trtllm_sparse_attn_attention_backend(
sparse_attn_config: "SparseAttentionConfig"):
if sparse_attn_config.algorithm == "rocket":
return RocketTrtllmAttention
else:
raise ValueError(
f"Unsupported sparse attention algorithm in trtllm attention backend: {sparse_attn_config.algorithm}"
)
def get_flashinfer_sparse_attn_attention_backend(
sparse_attn_config: "SparseAttentionConfig"):
raise ValueError(
f"Unsupported sparse attention algorithm in flashinfer attention backend: {sparse_attn_config.algorithm}"
)

View File

@ -190,6 +190,10 @@ class TrtllmAttentionWrapper:
spec_decoding_generation_lengths: Optional[torch.Tensor] = None,
attention_sinks: Optional[torch.Tensor] = None,
chunked_prefill_buffer_batch_size: int = 1,
sparse_kv_indices: Optional[torch.Tensor] = None,
sparse_kv_offsets: Optional[torch.Tensor] = None,
sparse_attn_indices: Optional[torch.Tensor] = None,
sparse_attn_offsets: Optional[torch.Tensor] = None,
**kwargs,
):
"""
@ -229,6 +233,10 @@ class TrtllmAttentionWrapper:
helix_position_offsets (torch.Tensor): The tensor to store the helix position offsets, with shape (num_tokens) on GPU.
attention_sinks (torch.Tensor): The attention sinks (additional value in the denominator of the softmax) with shape of (num_heads_q) on GPU.
chunked_prefill_buffer_batch_size (int): used for malloc buffer for k and v in fp8 context mla. the max input kv length is not max_num_tokens in this case. It is chunked_prefill_buffer_batch_size * max_num_tokens.
sparse_kv_indices (torch.Tensor): The sparse indices for the KV cache, with shape of (num_heads_kv, num_sparse_tokens) on GPU.
sparse_kv_offsets (torch.Tensor): The batch offsets for the sparse KV indices, with shape of (num_contexts + 1) on GPU.
sparse_attn_indices (torch.Tensor): The sparse indices for the attention layer, with shape of (num_heads_kv, num_sparse_tokens) on GPU.
sparse_attn_offsets (torch.Tensor): The batch offsets for the sparse attention indices, with shape of (num_generations + 1) on GPU.
"""
self.layer_idx = layer_idx
self.tokens_per_block = tokens_per_block
@ -266,7 +274,10 @@ class TrtllmAttentionWrapper:
self.softmax_stats_tensor = softmax_stats_tensor
self.helix_position_offsets = helix_position_offsets
self.attention_sinks = attention_sinks
self.sparse_kv_indices = sparse_kv_indices
self.sparse_kv_offsets = sparse_kv_offsets
self.sparse_attn_indices = sparse_attn_indices
self.sparse_attn_offsets = sparse_attn_offsets
if max_sequence_length > self.rope_params.max_positions:
self.rope_params.max_positions = max_sequence_length
self.rotary_inv_freq, self.rotary_cos_sin = self.rope_params.create_rope_const_params(
@ -424,6 +435,12 @@ class TrtllmAttentionWrapper:
self.spec_decoding_position_offsets, self.spec_decoding_packed_mask
]
mla_tensor_params = [self.helix_position_offsets]
sparse_attention_params = [
self.sparse_kv_indices,
self.sparse_kv_offsets,
self.sparse_attn_indices,
self.sparse_attn_offsets,
]
thop.attention(
q,
@ -491,6 +508,7 @@ class TrtllmAttentionWrapper:
self.softmax_stats_tensor,
spec_decoding_bool_params,
spec_decoding_tensor_params,
sparse_attention_params,
)
# reset the planned states (especially tensors) to avoid memory leak
self.plan()
@ -1239,6 +1257,14 @@ class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]):
use_paged_context_fmha=use_paged_context_fmha,
is_mla_enable=self.is_mla_enable,
)
sparse_kv_indices, sparse_kv_offsets, sparse_attn_indices, sparse_attn_offsets = None, None, None, None
if self.sparse_attention_config is not None:
sparse_kv_indices, sparse_kv_offsets = self.sparse_kv_predict(
q, k, metadata)
sparse_attn_indices, sparse_attn_offsets = self.sparse_attn_predict(
q, k, metadata)
self.wrapper.plan(
layer_idx=self.get_local_layer_idx(metadata),
tokens_per_block=metadata.tokens_per_block,
@ -1287,6 +1313,10 @@ class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]):
spec_decoding_generation_lengths,
attention_sinks=attention_sinks,
chunked_prefill_buffer_batch_size=chunked_prefill_buffer_batch_size,
sparse_kv_indices=sparse_kv_indices,
sparse_kv_offsets=sparse_kv_offsets,
sparse_attn_indices=sparse_attn_indices,
sparse_attn_offsets=sparse_attn_offsets,
)
out_dtype = None
if out_scale is not None:
@ -1500,3 +1530,27 @@ class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]):
self.num_heads,
self.mla_params.v_head_dim,
)
def sparse_attn_predict(
self,
q: torch.Tensor,
k: Optional[torch.Tensor],
metadata: TrtllmAttentionMetadata,
**kwargs,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
"""
Predict sparse attention indices. Must be implemented by a derived class.
"""
raise NotImplementedError
def sparse_kv_predict(
self,
q: torch.Tensor,
k: Optional[torch.Tensor],
metadata: TrtllmAttentionMetadata,
**kwargs,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
"""
Predict sparse KV indices. Must be implemented by a derived class.
"""
raise NotImplementedError

View File

@ -3,22 +3,33 @@ from typing import Optional, Type
from ...models.modeling_utils import QuantConfig
from ..flashinfer_utils import IS_FLASHINFER_AVAILABLE
from .interface import AttentionBackend, MLAParams, PositionalEmbeddingParams
from .sparse import (get_flashinfer_sparse_attn_attention_backend,
get_trtllm_sparse_attn_attention_backend,
get_vanilla_sparse_attn_attention_backend)
from .trtllm import TrtllmAttention
from .vanilla import VanillaAttention
def get_attention_backend(backend_name: str) -> Type[AttentionBackend]:
def get_attention_backend(
backend_name: str,
sparse_attn_config: Optional["SparseAttentionConfig"] = None
) -> Type[AttentionBackend]:
if backend_name == "VANILLA":
if sparse_attn_config is not None:
return get_vanilla_sparse_attn_attention_backend(sparse_attn_config)
return VanillaAttention
elif backend_name == "TRTLLM":
if sparse_attn_config is not None:
return get_trtllm_sparse_attn_attention_backend(sparse_attn_config)
return TrtllmAttention
elif backend_name == "FLASHINFER" and IS_FLASHINFER_AVAILABLE:
from .flashinfer import FlashInferAttention
if sparse_attn_config is not None:
return get_flashinfer_sparse_attn_attention_backend(
sparse_attn_config)
return FlashInferAttention
elif backend_name == "FLASHINFER_STAR_ATTENTION" and IS_FLASHINFER_AVAILABLE:
from .star_flashinfer import StarAttention
return StarAttention
return TrtllmAttention
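# Illustrative only (not part of this diff): with a sparse attention config, the dispatch
# above returns the sparse variant of the requested backend, e.g.
#   attn_cls = get_attention_backend(
#       "TRTLLM", sparse_attn_config=RocketSparseAttentionConfig())
#   # -> RocketTrtllmAttention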
@ -42,12 +53,13 @@ def create_attention(
predicted_tokens_per_seq: Optional[int] = 1,
skip_create_weights_in_init: bool = False,
attention_chunk_size: Optional[int] = None,
sparse_attention_config: Optional["SparseAttentionConfig"] = None,
):
if attention_chunk_size is not None and backend_name.upper() != "TRTLLM":
raise ValueError(
f"Backend {backend_name} does not support chunked attention.")
attn_cls = get_attention_backend(backend_name)
attn_cls = get_attention_backend(backend_name, sparse_attention_config)
if is_mla_enable:
assert attn_cls.support_mla(
@ -76,4 +88,5 @@ def create_attention(
mla_params=mla_params,
skip_create_weights_in_init=skip_create_weights_in_init,
attention_chunk_size=attention_chunk_size,
sparse_attention_config=sparse_attention_config,
)

View File

@ -13,6 +13,7 @@ except ImportError:
from .interface import (AttentionBackend, AttentionMask, AttentionMetadata,
PredefinedAttentionMask)
from .sparse.kernel import triton_index_gather
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
@ -94,90 +95,157 @@ class VanillaAttention(AttentionBackend[VanillaAttentionMetadata]):
self.num_key_value_groups = self.num_heads // self.num_kv_heads
self.q_scaling = q_scaling
def _single_request_update_kv_cache(self, k, v, kv_cache_tensor, seq_len,
cache_idx, cache_position):
def _single_request_sparse_attn_predict(
self, q: torch.Tensor, k: Optional[torch.Tensor],
v: Optional[torch.Tensor], kv_cache_tensor: torch.Tensor,
metadata: AttentionMetadata, past_seen_token: int, sample_idx: int,
**kwargs) -> tuple[Optional[torch.Tensor], int]:
raise NotImplementedError
def _single_request_sparse_kv_predict(
self, q: Optional[torch.Tensor], k: Optional[torch.Tensor],
v: Optional[torch.Tensor], metadata: AttentionMetadata,
past_seen_token: int, sample_idx: int,
**kwargs) -> tuple[Optional[torch.Tensor], int]:
raise NotImplementedError
def _single_request_update_kv_cache(self,
k,
v,
kv_cache_tensor,
past_seen_token,
kv_len,
cache_idx,
sparse_kv_indices=None):
# select tokens using the sparse kv indices
if sparse_kv_indices is not None:
k_selected = triton_index_gather(k, sparse_kv_indices)
v_selected = triton_index_gather(v, sparse_kv_indices)
else:
k_selected, v_selected = k, v
# get cache position
seq_len = past_seen_token + kv_len
cache_position = torch.arange(past_seen_token,
seq_len,
device=kv_cache_tensor.device)
# get kv cache tensor
k_out = kv_cache_tensor[cache_idx, 0, :, :, :].unsqueeze(0)
v_out = kv_cache_tensor[cache_idx, 1, :, :, :].unsqueeze(0)
# update kv cache
if k is not None and v is not None:
access_type = self._access_type[k.dtype.itemsize]
k_out.view(dtype=access_type).index_copy_(1, cache_position,
k.view(dtype=access_type))
v_out.view(dtype=access_type).index_copy_(1, cache_position,
v.view(dtype=access_type))
access_type = self._access_type[k_selected.dtype.itemsize]
k_out.view(dtype=access_type).index_copy_(
1, cache_position, k_selected.view(dtype=access_type))
v_out.view(dtype=access_type).index_copy_(
1, cache_position, v_selected.view(dtype=access_type))
return k_out[:, :seq_len, :, :], v_out[:, :seq_len, :, :]
def _single_request_forward(self,
q,
k,
v,
attention_mask: AttentionMask,
kv_cache_tensor,
past_seen_token,
cache_idx,
attention_window_size: Optional[int] = None):
# return past kv and the dense kv tensors for sparse attention
if sparse_kv_indices is not None:
k_states = torch.cat([k_out[:, :past_seen_token, :, :], k], dim=1)
v_states = torch.cat([v_out[:, :past_seen_token, :, :], v], dim=1)
else:
k_states, v_states = k_out[:, :seq_len, :, :], v_out[:, :seq_len, :, :]
return k_states, v_states
def _single_request_preprocess_inputs(self, q, k, v, kv_dtype):
bsz = 1
q_len = q.size(0)
# Query
q = q.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
# Key and Value
target_seq_len = past_seen_token
kv_len = 0
if k is not None and v is not None:
kv_len = k.size(0)
k = k.view(bsz, kv_len, self.num_kv_heads, self.head_dim)
v = v.view(bsz, kv_len, self.num_kv_heads, self.head_dim)
target_seq_len += kv_len
if self.quant_config and self.quant_config.layer_quant_mode.has_any_quant(
):
qc = self.quant_config
if qc.layer_quant_mode.has_fp8_kv_cache():
assert kv_cache_tensor.dtype == torch.float8_e4m3fn, f"KV cache should have fp8 dtype, but get {kv_cache_tensor.dtype}"
assert kv_dtype == torch.float8_e4m3fn, \
f"KV cache should have fp8 dtype, but get {kv_dtype}"
k = k.to(torch.float8_e4m3fn)
v = v.to(torch.float8_e4m3fn)
assert k.dtype == v.dtype == kv_cache_tensor.dtype, f"KV cache dtype {kv_cache_tensor.dtype} does not match k/v dtype {k.dtype}/{v.dtype}"
assert k.dtype == v.dtype == kv_dtype, \
f"KV cache dtype {kv_dtype} does not match k/v dtype {k.dtype}/{v.dtype}"
cache_position = torch.arange(past_seen_token,
target_seq_len,
device=q.device)
return q, k, v, kv_len
key_states, value_states = self._single_request_update_kv_cache(
k, v, kv_cache_tensor, target_seq_len, cache_idx, cache_position)
def _single_request_create_attention_mask(self,
attention_mask,
past_seen_token,
kv_len,
q_device,
q_len,
attention_window_size=None):
"""
Create appropriate attention mask based on the attention type.
key_states = key_states.transpose(1, 2).to(q.dtype)
value_states = value_states.transpose(1, 2).to(q.dtype)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
# Attention Mask
Returns:
Tuple of (is_causal, attn_mask)
"""
bsz = 1
is_causal = False
attn_mask = None
# get cache position
seq_len = past_seen_token + kv_len
cache_position = torch.arange(past_seen_token, seq_len, device=q_device)
# create attention mask
if attention_mask == PredefinedAttentionMask.CAUSAL:
# Create custom sliding window mask as sdpa doesn't natively support it.
if attention_window_size is not None:
attn_mask = generate_sliding_window_mask(
bsz, target_seq_len, cache_position, q.device,
bsz, seq_len, cache_position, q_device,
attention_window_size)
elif past_seen_token == 0:
is_causal = True
elif q_len != 1:
# attn_mask: 4-D tensor (batch_size, 1, query_seq_len, seq_len)
attn_mask = generate_causal_mask(bsz, target_seq_len,
cache_position, q.device)
attn_mask = generate_causal_mask(bsz, seq_len, cache_position,
q_device)
elif attention_mask == PredefinedAttentionMask.FULL:
pass
else:
raise ValueError("Unexpected attention mask type")
return attn_mask, is_causal
def _single_request_attn_forward(self,
q,
key_states,
value_states,
is_causal,
attn_mask,
sparse_indices=None):
"""
Common attention computation using scaled dot-product attention.
"""
# select the key and value states using the sparse indices
if sparse_indices is not None:
key_states = triton_index_gather(key_states, sparse_indices)
value_states = triton_index_gather(value_states, sparse_indices)
# transpose kv
key_states = key_states.transpose(1, 2).to(q.dtype)
value_states = value_states.transpose(1, 2).to(q.dtype)
# repeat kv to support MQA/GQA
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
# get qk scale
qk_scale = None
if self.q_scaling is not None:
qk_scale = 1 / (math.sqrt(self.head_dim) * self.q_scaling)
# attention
attn_output = torch.nn.functional.scaled_dot_product_attention(
q,
key_states,
@ -187,21 +255,66 @@ class VanillaAttention(AttentionBackend[VanillaAttentionMetadata]):
scale=qk_scale,
)
attn_output = attn_output.squeeze(0)
return attn_output
@staticmethod
def _single_request_forward(self,
q,
k,
v,
attention_mask: AttentionMask,
kv_cache_tensor,
past_seen_token,
cache_idx,
sample_idx,
metadata: AttentionMetadata,
attention_window_size: Optional[int] = None,
**kwargs):
# preprocess inputs
q, k, v, kv_len = self._single_request_preprocess_inputs(
q, k, v, kv_cache_tensor.dtype)
# predict sparse kv indices
sparse_kv_indices = None
if self.sparse_attention_config is not None:
sparse_kv_indices, kv_len = self._single_request_sparse_kv_predict(
q, k, v, metadata, past_seen_token, sample_idx)
# update kv cache
key_states, value_states = self._single_request_update_kv_cache(
k, v, kv_cache_tensor, past_seen_token, kv_len, cache_idx,
sparse_kv_indices)
# predict sparse attn indices
sparse_indices = None
if self.sparse_attention_config is not None:
sparse_indices, kv_len = self._single_request_sparse_attn_predict(
q, k, v, kv_cache_tensor, metadata, past_seen_token, sample_idx)
# create attention mask
attn_mask, is_causal = self._single_request_create_attention_mask(
attention_mask, past_seen_token, kv_len, q.device, q.size(2),
attention_window_size)
# attention
attn_output = self._single_request_attn_forward(q, key_states,
value_states, is_causal,
attn_mask,
sparse_indices)
return attn_output.squeeze(0)
def no_kv_cache_forward(
q: torch.Tensor,
k: Optional[torch.Tensor],
v: Optional[torch.Tensor],
num_heads: int,
num_kv_heads: int,
metadata: AttentionMetadata,
*,
attention_mask: AttentionMask = PredefinedAttentionMask.CAUSAL,
position_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
self,
q: torch.Tensor,
k: Optional[torch.Tensor],
v: Optional[torch.Tensor],
num_heads: int,
num_kv_heads: int,
metadata: AttentionMetadata,
*,
attention_mask: AttentionMask = PredefinedAttentionMask.CAUSAL,
position_ids: Optional[torch.Tensor] = None,
**kwargs) -> torch.Tensor:
"""
This function is used to perform attention without kv cache.
Args:
@ -275,14 +388,14 @@ class VanillaAttention(AttentionBackend[VanillaAttentionMetadata]):
# try to separate the kv cache estimation path from no kv cache attn.
num_heads = self.num_heads
num_kv_heads = self.num_kv_heads
return VanillaAttention.no_kv_cache_forward(
q=q,
k=k,
v=v,
num_heads=num_heads,
num_kv_heads=num_kv_heads,
metadata=metadata,
attention_mask=attention_mask)
return self.no_kv_cache_forward(q=q,
k=k,
v=v,
num_heads=num_heads,
num_kv_heads=num_kv_heads,
metadata=metadata,
attention_mask=attention_mask,
**kwargs)
past_seen_tokens = metadata.kv_cache_params.num_cached_tokens_per_seq
cache_indices = [
@ -298,7 +411,7 @@ class VanillaAttention(AttentionBackend[VanillaAttentionMetadata]):
offset = 0
offset_kv = 0
attn_outputs = []
for i, (seq_len, seq_len_kv) in enumerate(
for sample_idx, (seq_len, seq_len_kv) in enumerate(
zip(metadata.seq_lens, metadata.seq_lens_kv)):
single_q = q[offset:offset + seq_len]
single_k = k[
@ -306,13 +419,16 @@ class VanillaAttention(AttentionBackend[VanillaAttentionMetadata]):
seq_len_kv] if k is not None and seq_len_kv != 0 else None
single_v = v[
offset_kv:offset_kv +
seq_len_kv] if k is not None and seq_len_kv != 0 else None
past_seen_token = past_seen_tokens[i]
cache_idx = cache_indices[i]
seq_len_kv] if v is not None and seq_len_kv != 0 else None
past_seen_token = past_seen_tokens[sample_idx]
cache_idx = cache_indices[sample_idx]
attn_output = self._single_request_forward(
single_q, single_k, single_v, attention_mask, kv_cache_tensor,
past_seen_token, cache_idx, attention_window_size)
past_seen_token, cache_idx, sample_idx, metadata,
attention_window_size, **kwargs)
attn_outputs.append(attn_output)
offset += seq_len

View File

@ -121,6 +121,7 @@ class ModelConfig(Generic[TConfig]):
spec_config: Optional["DecodingBaseConfig"] = None
lora_config: Optional["LoraConfig"] = None
sparse_attention_config: Optional["SparseAttentionConfig"] = None
is_generation: bool = True
max_num_tokens: int = 8192

View File

@ -259,7 +259,9 @@ class Attention(nn.Module):
self.quant_config = config.get_quant_config()
self.attn_backend = config.attn_backend
attn_cls = get_attention_backend(self.attn_backend)
attn_cls = get_attention_backend(
self.attn_backend,
sparse_attn_config=config.sparse_attention_config)
# These two modules are mutually exclusive - either splitted_qkv_lora or fused_qkv_lora will be used,
# but never both at the same time. splitted_qkv_lora handles Q,K,V separately while fused_qkv_lora
@ -277,6 +279,9 @@ class Attention(nn.Module):
# Whether to fuse RoPE into the attention OP.
# If true, RoPE will be applied in self.attn.forward.
# If false, RoPE will be applied in self.apply_rope.
if config.sparse_attention_config is not None:
logger.warning("disable rope_fusion for sparse attention.")
rope_fusion = False
self.rope_fusion = rope_fusion
if self.rope_fusion and not attn_cls.support_fused_rope():
logger.warning(
@ -314,6 +319,7 @@ class Attention(nn.Module):
skip_create_weights_in_init=config.skip_create_weights_in_init,
q_scaling=self.q_scaling,
attention_chunk_size=self.attention_chunk_size,
sparse_attention_config=config.sparse_attention_config,
)
self.support_fused_qkv = self.attn.support_fused_qkv()
@ -854,6 +860,7 @@ class MLA(nn.Module):
v_head_dim=self.v_head_dim,
predicted_tokens_per_seq=self.predicted_tokens_per_seq,
skip_create_weights_in_init=config.skip_create_weights_in_init,
sparse_attention_config=config.sparse_attention_config,
)
self.mqa = create_attention(
@ -873,6 +880,7 @@ class MLA(nn.Module):
v_head_dim=self.kv_lora_rank,
predicted_tokens_per_seq=self.predicted_tokens_per_seq,
skip_create_weights_in_init=config.skip_create_weights_in_init,
sparse_attention_config=config.sparse_attention_config,
)
self.aux_stream = aux_stream

View File

@ -1,6 +1,5 @@
import os
import random
from collections.abc import Iterable
from typing import Dict, List, Optional
import torch
@ -12,14 +11,15 @@ from tensorrt_llm._utils import str_dtype_to_binding, torch_dtype_to_str
from tensorrt_llm.bindings.executor import DecodingMode
from tensorrt_llm.llmapi.llm_args import (EagleDecodingConfig, KvCacheConfig,
MTPDecodingConfig, PeftCacheConfig,
SamplerType, SpeculativeConfig,
TorchLlmArgs)
SamplerType, SparseAttentionConfig,
SpeculativeConfig, TorchLlmArgs)
from tensorrt_llm.logger import logger
from tensorrt_llm.lora_helper import (LoraConfig,
get_default_trtllm_modules_to_hf_modules)
from tensorrt_llm.lora_manager import load_torch_lora
from tensorrt_llm.mapping import CpType, Mapping
from ..attention_backend import get_sparse_attn_kv_cache_manager
from ..model_config import ModelConfig
from ..speculative import get_num_extra_kv_tokens, get_spec_decoder
from .config import PyTorchConfig
@ -42,6 +42,20 @@ from .seq_slot_manager import SeqSlotManager
GB = 1 << 30
def get_kv_cache_manager_cls(model_config: ModelConfig):
config = model_config.pretrained_config
sparse_attn_config = model_config.sparse_attention_config
if is_mla(config):
return KVCacheManager
elif is_nemotron_hybrid(config):
return MambaHybridCacheManager
else:
if sparse_attn_config is not None:
return get_sparse_attn_kv_cache_manager(sparse_attn_config)
else:
return KVCacheManager
class KvCacheCreator:
"""Groups together logic related to KV cache construction."""
@ -61,6 +75,7 @@ class KvCacheCreator:
kv_cache_config: KvCacheConfig,
pytorch_backend_config: PyTorchConfig,
speculative_config: SpeculativeConfig,
sparse_attention_config: SparseAttentionConfig,
):
self._model_engine = model_engine
self._draft_model_engine = draft_model_engine
@ -72,50 +87,14 @@ class KvCacheCreator:
self._kv_connector_manager = kv_connector_manager
self._pytorch_backend_config = pytorch_backend_config
self._speculative_config = speculative_config
self._sparse_attention_config = sparse_attention_config
self._tokens_per_block = tokens_per_block
self._max_seq_len = max_seq_len
self._max_batch_size = max_batch_size
self._net_max_seq_len = net_max_seq_len
self._dummy_reqs = None
@staticmethod
def _get_cache_size_per_token(model_config: ModelConfig,
mapping: Mapping) -> int:
mem_per_token = 2
quant_config = model_config.quant_config
if quant_config is not None and quant_config.quant_mode.has_fp8_kv_cache(
):
mem_per_token = 1
config = model_config.pretrained_config
num_key_value_heads = getattr(config, 'num_key_value_heads',
config.num_attention_heads)
if isinstance(num_key_value_heads, Iterable):
num_key_value_heads = sum(num_key_value_heads) / len(
num_key_value_heads)
mla = is_mla(config)
tp_size = 1 if mapping.enable_attention_dp else mapping.tp_size
kv_factor = 2
if mla:
# MLA has kv_lora_rank and qk_rope_head_dim
head_dim = config.kv_lora_rank + config.qk_rope_head_dim
kv_factor = 1
else:
_head_dim = getattr(config, 'head_dim', None)
if not isinstance(_head_dim, int):
_head_dim = config.hidden_size // config.num_attention_heads
head_dim = _head_dim * num_key_value_heads // tp_size
# provide at least 1 layer to prevent division by zero cache size
num_attention_layers = max(
len(mapping.pp_layers(model_config.get_num_attention_layers())), 1)
mem_per_token *= num_attention_layers * head_dim
# K and V
mem_per_token *= kv_factor
return mem_per_token
self._kv_cache_manager_cls = get_kv_cache_manager_cls(
model_engine.model.model_config)
def _get_free_gpu_memory_fraction(self) -> float:
fraction = self._kv_cache_config.free_gpu_memory_fraction
@ -126,12 +105,14 @@ class KvCacheCreator:
def _get_kv_size_per_token(self):
model_config = self._model_engine.model.model_config
mapping = self._mapping
kv_size_per_token = self._get_cache_size_per_token(
model_config, mapping)
kv_size_per_token = self._kv_cache_manager_cls.get_cache_size_per_token(
model_config, mapping, tokens_per_block=self._tokens_per_block)
if self._draft_model_engine is not None:
draft_model_config = self._draft_model_engine.model.model_config
kv_size_per_token += self._get_cache_size_per_token(
draft_model_config, mapping)
kv_size_per_token += self._kv_cache_manager_cls.get_cache_size_per_token(
draft_model_config,
mapping,
tokens_per_block=self._tokens_per_block)
return kv_size_per_token
def _cal_max_memory(self, peak_memory, total_gpu_memory, fraction,
@ -231,6 +212,12 @@ class KvCacheCreator:
estimating_kv_cache = True
self._kv_cache_config.max_tokens = self._get_token_num_for_estimation(
)
model_config = self._model_engine.model.model_config
if model_config.attn_backend == "VANILLA":
logger.info(
"KV cache size estimation is not supported for Vanilla attention backend, disable it."
)
estimating_kv_cache = False
return estimating_kv_cache
def configure_kv_cache_capacity(self, py_executor: PyExecutor) -> None:
@ -374,6 +361,7 @@ class KvCacheCreator:
config = model_engine.model.model_config.pretrained_config
quant_config = model_engine.model.model_config.quant_config
spec_config = self._speculative_config
sparse_attn_config = self._sparse_attention_config
hidden_size = config.hidden_size
num_attention_heads = config.num_attention_heads
@ -396,7 +384,7 @@ class KvCacheCreator:
num_hidden_layers = config.num_hidden_layers
if is_mla(config):
kv_cache_manager = KVCacheManager(
kv_cache_manager = self._kv_cache_manager_cls(
self._kv_cache_config,
tensorrt_llm.bindings.internal.batch_manager.CacheType.
SELFKONLY,
@ -434,8 +422,7 @@ class KvCacheCreator:
mamba_layer_mask = [
char == "M" for char in config.hybrid_override_pattern
]
kv_cache_manager = MambaHybridCacheManager(
kv_cache_manager = self._kv_cache_manager_cls(
# mamba cache parameters
config.ssm_state_size,
config.conv_kernel,
@ -518,7 +505,7 @@ class KvCacheCreator:
binding_model_config = model_engine.model.model_config.get_bindings_model_config(
tokens_per_block=self._tokens_per_block) if is_vswa else None
kv_cache_manager = KVCacheManager(
kv_cache_manager = self._kv_cache_manager_cls(
self._kv_cache_config,
tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF,
num_layers=num_hidden_layers,
@ -536,6 +523,7 @@ class KvCacheCreator:
is_draft=model_engine.is_draft_model,
kv_connector_manager=self._kv_connector_manager
if not estimating_kv_cache else None,
sparse_attn_config=sparse_attn_config,
)
# KVCacheManager (Non-draft) modifies the max_seq_len field, update it to self
if model_engine.kv_cache_manager_key == ResourceManagerType.KV_CACHE_MANAGER:

View File

@ -137,6 +137,7 @@ class PyTorchModelEngine(ModelEngine):
attn_runtime_features: Optional[AttentionRuntimeFeatures] = None,
dist: Optional[MPIDist] = None,
spec_config: Optional["DecodingBaseConfig"] = None,
sparse_attention_config: Optional["SparseAttentionConfig"] = None,
lora_config: Optional[LoraConfig] = None,
is_draft_model: bool = False,
drafting_loop_wrapper: Optional[Callable[[torch.nn.Module],
@ -164,6 +165,7 @@ class PyTorchModelEngine(ModelEngine):
spec_config.max_draft_len = 0
self.spec_config = spec_config
self.is_spec_decode = spec_config is not None
self.sparse_attention_config = sparse_attention_config
self.enable_spec_decode = self.is_spec_decode
self.is_draft_model = is_draft_model
@ -175,6 +177,7 @@ class PyTorchModelEngine(ModelEngine):
pytorch_backend_config=pytorch_backend_config,
mapping=self.mapping,
spec_config=self.spec_config,
sparse_attention_config=self.sparse_attention_config,
max_num_tokens=max_num_tokens,
max_seq_len=max_seq_len,
lora_config=lora_config,
@ -261,7 +264,8 @@ class PyTorchModelEngine(ModelEngine):
self.is_warmup = False
self.attn_backend = get_attention_backend(
pytorch_backend_config.attn_backend)
pytorch_backend_config.attn_backend,
sparse_attn_config=sparse_attention_config)
if self.is_spec_decode:
self.spec_metadata = None
@ -794,7 +798,8 @@ class PyTorchModelEngine(ModelEngine):
enable_flash_mla=self.model.model_config.enable_flash_mla,
enable_context_mla_with_cached_kv=
enable_context_mla_with_cached_kv,
cache_indirection=cache_indirection)
cache_indirection=cache_indirection,
sparse_attention_config=self.sparse_attention_config)
if self.attn_metadata is not None:
# This assertion can be relaxed if needed: just create a new metadata
@ -811,7 +816,8 @@ class PyTorchModelEngine(ModelEngine):
runtime_features=self.attn_runtime_features,
enable_flash_mla=self.model.model_config.enable_flash_mla,
enable_context_mla_with_cached_kv=enable_context_mla_with_cached_kv,
cache_indirection=cache_indirection)
cache_indirection=cache_indirection,
sparse_attention_config=self.sparse_attention_config)
return self.attn_metadata

View File

@ -160,6 +160,7 @@ class ModelLoader:
pytorch_backend_config: PyTorchConfig,
mapping: Mapping,
spec_config: Optional["DecodingBaseConfig"],
sparse_attention_config: Optional["SparseAttentionConfig"],
max_num_tokens: int,
max_seq_len: Optional[int],
lora_config: Optional[LoraConfig] = None):
@ -177,6 +178,7 @@ class ModelLoader:
self.pytorch_backend_config = pytorch_backend_config
self.mapping = mapping
self.spec_config = spec_config
self.sparse_attention_config = sparse_attention_config
self.max_num_tokens = max_num_tokens
self.max_seq_len = max_seq_len
self.lora_config = lora_config
@ -294,6 +296,7 @@ class ModelLoader:
force_dynamic_quantization=self.pytorch_backend_config.
force_dynamic_quantization,
spec_config=self.spec_config,
sparse_attention_config=self.sparse_attention_config,
max_num_tokens=self.max_num_tokens,
max_seq_len=self.max_seq_len,
moe_max_num_tokens=self.pytorch_backend_config.moe_max_num_tokens,

View File

@ -252,6 +252,8 @@ def create_py_executor(
max_num_tokens = 8192
tokens_per_block = kv_cache_config.tokens_per_block
if pytorch_backend_config.attn_backend == "VANILLA":
tokens_per_block = max_num_tokens
if pytorch_backend_config.attn_backend in [
"FLASHINFER", "FLASHINFER_STAR_ATTENTION"
@ -308,6 +310,8 @@ def create_py_executor(
has_draft_model_engine = spec_config.spec_dec_mode.has_draft_model()
has_spec_drafter = spec_config.spec_dec_mode.has_spec_drafter()
sparse_attention_config = llm_args.sparse_attention_config
# chunk_unit_size may be changed to 64 when using flash mla
attn_runtime_features = AttentionRuntimeFeatures(
chunked_prefill=enable_chunked_context,
@ -331,6 +335,7 @@ def create_py_executor(
attn_runtime_features=attn_runtime_features,
dist=dist,
spec_config=spec_config,
sparse_attention_config=sparse_attention_config,
lora_config=lora_config,
checkpoint_loader=checkpoint_loader,
)
@ -568,6 +573,7 @@ def create_py_executor(
kv_cache_config=kv_cache_config,
pytorch_backend_config=pytorch_backend_config,
speculative_config=spec_config,
sparse_attention_config=sparse_attention_config,
)
estimating_kv_cache = kv_cache_creator.try_prepare_estimation()
with mem_monitor.observe_creation_stage(

View File

@ -3,7 +3,7 @@ import enum
import math
from abc import ABC, abstractmethod
from collections import OrderedDict, defaultdict
from typing import Dict, List, Optional, Set, Tuple, Union
from typing import Dict, Iterable, List, Optional, Set, Tuple, Union
import torch
@ -165,6 +165,7 @@ class KVCacheManager(BaseResourceManager):
max_beam_width: int = 1,
is_draft: bool = False,
kv_connector_manager: Optional[KvCacheConnectorManager] = None,
**kwargs,
) -> None:
self.mapping = mapping
self.dtype = dtype
@ -543,6 +544,64 @@ class KVCacheManager(BaseResourceManager):
return get_size_in_bytes(cache_size // quant_vector_size,
scaling_factor_dtype)
# TODO: refactor get_cache_size_per_token and get_cache_bytes_per_token to use the same logic
@staticmethod
def get_cache_size_per_token(model_config: ModelConfigPython,
mapping: Mapping, **kwargs):
# get kv cache dtype bytes
mem_per_token = 2
quant_config = model_config.quant_config
if quant_config is not None and quant_config.quant_mode.has_fp8_kv_cache(
):
mem_per_token = 1
# get num key value heads
config = model_config.pretrained_config
num_key_value_heads = getattr(config, 'num_key_value_heads',
config.num_attention_heads)
if isinstance(num_key_value_heads, Iterable):
num_key_value_heads = sum(num_key_value_heads) / len(
num_key_value_heads)
# get head dim
mla = hasattr(config, "kv_lora_rank")
if mla:
head_dim = config.kv_lora_rank + config.qk_rope_head_dim
kv_factor = 1
else:
tp_size = 1 if mapping.enable_attention_dp else mapping.tp_size
head_dim = getattr(config, "head_dim", None)
if not isinstance(head_dim, int):
head_dim = config.hidden_size // config.num_attention_heads
head_dim = head_dim * num_key_value_heads // tp_size
kv_factor = 2
# provide at least 1 layer to prevent division by zero cache size
num_attention_layers = max(
len(mapping.pp_layers(model_config.get_num_attention_layers())), 1)
mem_per_token *= num_attention_layers * head_dim
# K and V
mem_per_token *= kv_factor
return mem_per_token
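# Worked example (illustrative, assuming a Llama-3.1-8B-like config: 32 attention layers,
# 8 KV heads, head_dim 128, TP=1, fp16 KV cache so 2 bytes per element):
#   per-layer KV width = 8 heads * 128 dims               = 1024 elements
#   bytes per token    = 2 B * 32 layers * 1024 * 2 (K+V) = 131072 B = 128 KiB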
def get_cache_bytes_per_token(self):
cache_size_per_token = self.kv_factor * sum(
self.num_kv_heads_per_layer) * self.head_dim
if self.dtype not in (DataType.FP8, DataType.HALF, DataType.BF16,
DataType.FLOAT, DataType.NVFP4):
raise ValueError(f'Cannot support {self.dtype} KV cache.')
cache_size_bytes_per_token = get_size_in_bytes(cache_size_per_token,
self.dtype)
if self.dtype == DataType.NVFP4:
cache_size_bytes_per_token += self.calculate_scaling_factor_size_bytes(
cache_size_per_token,
quant_vector_size=16,
scaling_factor_dtype=DataType.FP8)
return cache_size_bytes_per_token
def calculate_max_num_blocks(self,
kv_cache_config: KvCacheConfig,
head_dim: int,
@ -554,20 +613,7 @@ class KVCacheManager(BaseResourceManager):
if kv_cache_config.free_gpu_memory_fraction
is not None else 0.9)
cache_size_per_token = kv_factor * sum(
self.num_kv_heads_per_layer) * head_dim
if dtype not in (DataType.FP8, DataType.HALF, DataType.BF16,
DataType.FLOAT, DataType.NVFP4):
raise ValueError(f'Cannot support {dtype} KV cache.')
cache_size_bytes_per_token = get_size_in_bytes(cache_size_per_token,
dtype)
if dtype == DataType.NVFP4:
cache_size_bytes_per_token += self.calculate_scaling_factor_size_bytes(
cache_size_per_token,
quant_vector_size=16,
scaling_factor_dtype=DataType.FP8)
cache_size_bytes_per_token = self.get_cache_bytes_per_token()
free_mem, total_mem = torch.cuda.mem_get_info()

View File

@ -12,6 +12,7 @@ from .llm_args import (AttentionDpConfig, AutoDecodingConfig, BatchingType,
ExtendedRuntimePerfKnobConfig, KvCacheConfig, LlmArgs,
LookaheadDecodingConfig, MedusaDecodingConfig, MoeConfig,
MTPDecodingConfig, NGramDecodingConfig,
RocketSparseAttentionConfig,
SaveHiddenStatesDecodingConfig, SchedulerConfig,
TorchCompileConfig, TorchLlmArgs, TrtLlmArgs,
UserProvidedDecodingConfig)
@ -61,4 +62,5 @@ __all__ = [
'AttentionDpConfig',
'LoRARequest',
'SaveHiddenStatesDecodingConfig',
'RocketSparseAttentionConfig',
]

View File

@ -166,6 +166,63 @@ class CudaGraphConfig(StrictBaseModel):
return batch_sizes
class BaseSparseAttentionConfig(StrictBaseModel):
"""
Configuration for sparse attention.
"""
algorithm: Literal["rocket"] = Field(
default="rocket", description="The algorithm for sparse attention.")
@classmethod
def from_dict(cls, data: dict):
# dispatch to the correct sparse attention config
config_classes = {
"rocket": RocketSparseAttentionConfig,
}
algorithm = data.get("algorithm", None)
if algorithm is None:
raise ValueError("Sparse attention algorithm is required")
config_class = config_classes.get(algorithm.lower())
if config_class is None:
raise ValueError(f"Invalid algorithm: {algorithm}")
return config_class(**data)
def _check_fields(self):
pass
def supports_backend(self, backend: str) -> bool:
"""
Override if the sparse attention algorithm only supports
a subset of the possible backends.
"""
return True
class RocketSparseAttentionConfig(BaseSparseAttentionConfig):
"""
Configuration for RocketKV sparse attention.
"""
window_size: Optional[int] = Field(
default=None, description="The window size for SnapKV.")
kernel_size: Optional[int] = Field(
default=None, description="The kernel size for SnapKV.")
topr: Optional[Union[int, float]] = Field(default=76, description="Top-r")
topk: Optional[int] = Field(default=128, description="Top-k")
prompt_budget: Optional[int] = Field(default=1266,
description="Prompt budget")
page_size: Optional[int] = Field(default=3, description="Page size")
@classmethod
def from_dict(cls, data: dict):
return cls(**data)
def supports_backend(self, backend: str) -> bool:
return backend == "pytorch"
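A small usage sketch of the from_dict dispatch above; the field values are illustrative only.

from tensorrt_llm.llmapi.llm_args import (BaseSparseAttentionConfig,
                                          RocketSparseAttentionConfig)

cfg = BaseSparseAttentionConfig.from_dict({
    "algorithm": "rocket",   # dispatches to RocketSparseAttentionConfig
    "window_size": 32,
    "kernel_size": 63,
    "prompt_budget": 2048,
})
assert isinstance(cfg, RocketSparseAttentionConfig)
assert cfg.supports_backend("pytorch")   # the rocket config is PyTorch-only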
class MoeConfig(StrictBaseModel):
"""
Configuration for MoE.
@ -1133,6 +1190,10 @@ SpeculativeConfig: TypeAlias = Optional[Union[
AutoDecodingConfig,
]]
SparseAttentionConfig: TypeAlias = Union[
RocketSparseAttentionConfig,
]
@PybindMirror.mirror_pybind_fields(_KvCacheConfig)
class KvCacheConfig(StrictBaseModel, PybindMirror):
@ -1510,6 +1571,12 @@ class BaseLlmArgs(StrictBaseModel):
description="Cache transceiver config.",
status="prototype")
# Sparse attention config
sparse_attention_config: Optional[SparseAttentionConfig] = Field(
default=None,
description="Sparse attention config.",
status="prototype")
# Speculative decoding parameters
speculative_config: SpeculativeConfig = Field(
default=None, description="Speculative decoding config.")

View File

@ -0,0 +1,80 @@
import json
import os
import pytest
from utils.llm_data import llm_models_root
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import KvCacheConfig, RocketSparseAttentionConfig
@pytest.mark.parametrize("backend", ["pytorch"])
@pytest.mark.parametrize("model_name",
["llama-3.1-model/Llama-3.1-8B-Instruct"])
@pytest.mark.parametrize("attention_backend", ["VANILLA", "TRTLLM"])
def test_model(backend, model_name, attention_backend):
model_dir = str(llm_models_root() / model_name)
max_batch_size = 16
max_output_tokens = 128
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.7,
enable_block_reuse=False)
sparse_attention_config = RocketSparseAttentionConfig(
window_size=32,
kernel_size=63,
prompt_budget=2048,
)
llm = LLM(
model=model_dir,
backend=backend,
kv_cache_config=kv_cache_config,
attn_backend=attention_backend,
sparse_attention_config=sparse_attention_config,
max_batch_size=max_batch_size,
max_seq_len=8192,
max_num_tokens=8192,
cuda_graph_config=None,  # sparse attention does not support CUDA graphs yet
)
inputs, references = [], []
current_file = os.path.abspath(__file__)
current_dir = os.path.dirname(os.path.dirname(
os.path.dirname(current_file)))
input_file = f'{current_dir}/multi_gpu/test_star_attention_input.jsonl'
with open(input_file, 'r') as f:
for line in f:
sample = json.loads(line)
inputs.append({
'prompt':
sample['input_context'] + sample['input_query'],
})
references.append(sample['outputs'][0])
with llm:
outputs = llm.generate(
inputs,
use_tqdm=True,
sampling_params=SamplingParams(add_special_tokens=False,
max_tokens=max_output_tokens,
temperature=0.8,
top_p=0.95),
)
count = 0
for ref, ret in zip(references, outputs):
print(f"ret: {ret.outputs[0].text}")
print(f"ref: {ref}")
if ref not in ret.outputs[0].text:
print(f'reference {ref} is not in the output {ret.outputs[0].text}')
else:
count = count + 1
acc = count / len(outputs)
assert acc >= 0.9, 'accuracy test of RocketKV sparse attention failed'
if __name__ == '__main__':
test_model("pytorch", "llama-3.1-model/Llama-3.1-8B-Instruct", "VANILLA")
test_model("pytorch", "llama-3.1-model/Llama-3.1-8B-Instruct", "TRTLLM")

View File

@ -183,6 +183,10 @@ methods:
annotation: Optional[Literal["rpc", "ray"]]
default: null
status: prototype
sparse_attention_config:
annotation: Optional[tensorrt_llm.llmapi.llm_args.SparseAttentionConfig]
default: null
status: prototype
return_annotation: None
generate:
parameters: