mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-13 22:18:36 +08:00

[TRTLLM-8536][feat] Add the sparse attention framework and one use case--RocketKV support (#8086)

Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com>
Signed-off-by: yuhangh <58161490+heyuhhh@users.noreply.github.com>
Co-authored-by: yuhangh <58161490+heyuhhh@users.noreply.github.com>

This commit is contained in:
parent 7291cdc422
commit 0d20a8fd61
@@ -24,6 +24,7 @@
#include "tensorrt_llm/kernels/gptKernels.h"
#include "tensorrt_llm/kernels/kvCacheUtils.h"
#include "tensorrt_llm/kernels/multiHeadAttentionCommon.h"
#include "tensorrt_llm/kernels/sparseAttentionKernels.h"
#include "tensorrt_llm/kernels/unfusedAttentionKernels.h"
#include "tensorrt_llm/runtime/iBuffer.h"
#include "tensorrt_llm/runtime/utils/debugUtils.h"
@@ -287,6 +288,9 @@ bool AttentionOp::convertMMHAParamsToXQAParams(tensorrt_llm::kernels::XQAParams&
xqaParams.output_sf = generationsParams.context_buf_sf;
xqaParams.fp4_out_sf_scale = generationsParams.attention_output_sf_scale;
xqaParams.start_token_idx_sf = generationsParams.start_token_idx_sf;
// Parameters for sparse attention
xqaParams.sparse_params = mRuntimeSparseAttentionParams;
xqaParams.use_sparse_attention = useTllmGenSparseAttention();

// Cross attention parameters.
xqaParams.encoder_input_lengths = generationsParams.encoder_input_lengths;
@@ -813,7 +817,7 @@ size_t AttentionOp::getWorkspaceSizeForContext(nvinfer1::DataType type, int32_t
}

size_t AttentionOp::getWorkspaceSizeForGeneration(nvinfer1::DataType type, int32_t max_num_seq,
int32_t max_attention_window_size, int32_t max_num_tokens) const noexcept
int32_t max_attention_window_size, int32_t max_num_tokens, int32_t max_blocks_per_sequence) const noexcept
{
if (max_num_tokens == 0)
{
@@ -909,11 +913,15 @@ size_t AttentionOp::getWorkspaceSizeForGeneration(nvinfer1::DataType type, int32
size_t xqa_workspace_size = 0;
if (mEnableXQA)
{
int const XQA_NUM_BUFFERS = 7;
int const XQA_NUM_BUFFERS = 8;
size_t xqa_workspaces[XQA_NUM_BUFFERS];
size_t const cu_seqlens_size = sizeof(int) * (batch_beam + 1);
size_t const cu_kv_seqlens_size = sizeof(int) * (batch_beam + 1);
size_t const rotary_inv_freq_size = sizeof(float) * batch_beam * mRotaryEmbeddingDim / 2;
// Two workspaces for sparse attention. One for the sequence lengths, and one for kv block offsets.
size_t const sparse_attn_cache_size = useTllmGenSparseAttention()
? sizeof(int) * (batch_beam + batch_beam * 2 * max_blocks_per_sequence) * mNumKVHeads
: 0;
xqa_workspaces[0] = cu_seqlens_size;
xqa_workspaces[1] = cu_kv_seqlens_size;
xqa_workspaces[2] = rotary_inv_freq_size;
@@ -922,7 +930,8 @@ size_t AttentionOp::getWorkspaceSizeForGeneration(nvinfer1::DataType type, int32
// Scales used for trtllm-gen kernels.
xqa_workspaces[4] = sizeof(float) * 2;
xqa_workspaces[5] = sizeof(float);
xqa_workspaces[6] = mXqaDispatcher->getWorkspaceSize(
xqa_workspaces[6] = sparse_attn_cache_size;
xqa_workspaces[7] = mXqaDispatcher->getWorkspaceSize(
std::min<uint32_t>(mSpecDecodingMaxGenerationLength * max_num_seq, max_num_tokens));
xqa_workspace_size
= tc::calculateTotalWorkspaceSize(xqa_workspaces, XQA_NUM_BUFFERS, mXqaDispatcher->getWorkspaceAlignment());
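For context, the extra workspace sized above holds, per KV head, one gathered sequence length plus a gathered [2, max_blocks_per_sequence] page-offset table for every sequence. A minimal stand-alone sketch of the same arithmetic follows; the function and variable names are illustrative, not part of the repository API.

#include <cstddef>
#include <cstdio>

// Mirrors: sizeof(int) * (batch_beam + batch_beam * 2 * max_blocks_per_sequence) * num_kv_heads
std::size_t sparseAttnScratchBytes(std::size_t batch_beam, std::size_t max_blocks_per_sequence, std::size_t num_kv_heads)
{
    std::size_t const seq_len_ints = batch_beam;                                   // one gathered length per sequence
    std::size_t const page_offset_ints = batch_beam * 2 * max_blocks_per_sequence; // K and V page offsets per sequence
    return sizeof(int) * (seq_len_ints + page_offset_ints) * num_kv_heads;
}

int main()
{
    // 8 sequences, 128 pages per sequence, 8 KV heads -> 65792 bytes (~64 KiB) of int32 scratch.
    std::printf("%zu bytes\n", sparseAttnScratchBytes(8, 128, 8));
    return 0;
}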
@@ -1647,6 +1656,10 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
preprocessingParams.spec_decoding_position_offsets = nullptr;
preprocessingParams.logn_scaling = params.logn_scaling_ptr;

// Sparse KV write
preprocessingParams.sparse_kv_indices = mRuntimeSparseAttentionParams.sparse_kv_indices;
preprocessingParams.sparse_kv_offsets = mRuntimeSparseAttentionParams.sparse_kv_offsets;

// Scalars
preprocessingParams.batch_size = params.batch_size;
preprocessingParams.max_input_seq_len = params.input_seq_length;
@@ -1676,6 +1689,8 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea

preprocessingParams.rotary_vision_start = mVisionStart;
preprocessingParams.rotary_vision_length = mVisionLength;
preprocessingParams.is_last_chunk
= !mAttentionChunkSize.has_value() || (params.input_seq_length == params.max_past_kv_length);

{
std::string const beforeRopeStr = "ctx attention before RoPE at layer " + std::to_string(mLayerIdx);
@@ -1841,6 +1856,12 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
gatherInBuffer, params, cu_q_seqlens, cu_cp_partial_seqlens, stream);
sync_check_cuda_error(stream);
}

if (!mIsMLAEnabled) // Only for non-MLA attention
{
invokeKvCachePostprocessing(preprocessingParams, stream);
sync_check_cuda_error(stream);
}
}
else
{
@@ -26,6 +26,7 @@
#include "tensorrt_llm/kernels/gptKernels.h"
#include "tensorrt_llm/kernels/kvCacheUtils.h"
#include "tensorrt_llm/kernels/mlaKernels.h"
#include "tensorrt_llm/kernels/sparseAttentionKernels.h"
#include "tensorrt_llm/kernels/xqaDispatcher.h"
#include <cassert>
#include <set>
@@ -55,7 +56,7 @@ public:
int32_t cross_kv_length = 0, int32_t max_num_tokens = 0) const noexcept;
// total_num_seq is the sum of beam_width for multiple requests
[[nodiscard]] size_t getWorkspaceSizeForGeneration(nvinfer1::DataType type, int32_t total_num_seq,
int32_t max_attention_window_size, int32_t max_num_tokens) const noexcept;
int32_t max_attention_window_size, int32_t max_num_tokens, int32_t max_blocks_per_sequence) const noexcept;

template <typename T>
class EnqueueParams
@@ -156,14 +157,20 @@ public:
ss << "max_cyclic_attention_window_size: " << this->max_cyclic_attention_window_size << std::endl;
ss << "can_use_one_more_block: " << (this->can_use_one_more_block ? "true" : "false") << std::endl;
ss << "sink_token_length: " << this->sink_token_length << std::endl;
ss << "context_lengths: "
<< *(runtime::ITensor::wrap((void*) this->context_lengths, nvinfer1::DataType::kINT32,
runtime::ITensor::makeShape({batch_size})))
<< std::endl;
ss << "sequence_lengths: "
<< *(runtime::ITensor::wrap((void*) this->sequence_lengths, nvinfer1::DataType::kINT32,
runtime::ITensor::makeShape({batch_size})))
<< std::endl;
if (this->context_lengths && batch_size > 0)
{
ss << "context_lengths: "
<< *(runtime::ITensor::wrap((void*) this->context_lengths, nvinfer1::DataType::kINT32,
runtime::ITensor::makeShape({batch_size})))
<< std::endl;
}
if (this->sequence_lengths && batch_size > 0)
{
ss << "sequence_lengths: "
<< *(runtime::ITensor::wrap((void*) this->sequence_lengths, nvinfer1::DataType::kINT32,
runtime::ITensor::makeShape({batch_size})))
<< std::endl;
}
ss << "kv_scale_orig_quant: " << this->kv_scale_orig_quant << std::endl;
ss << "kv_scale_quant_orig: " << this->kv_scale_quant_orig << std::endl;
ss << "attention_output_orig_quant: " << this->attention_output_orig_quant << std::endl;
@@ -348,6 +355,16 @@ public:
return mIsMLAEnabled;
}

[[nodiscard]] bool useSparseAttention() const
{
return mUseSparseAttention && mPagedKVCache && mEnableXQA;
}

[[nodiscard]] bool useTllmGenSparseAttention() const
{
return mUseTllmGenSparseAttention && useSparseAttention();
}

[[nodiscard]] int smVersion() const
{
return mSM;
@@ -427,6 +444,8 @@ public:
bool mIsMLAEnabled = false;
bool mIsGenerationMLA = false;
bool mUseGenFlashMLA = false;
bool mUseSparseAttention = false;
bool mUseTllmGenSparseAttention = false;
tensorrt_llm::kernels::MlaMetaParams mMLAParams;
int mCpSize = 1;
int mCpRank = 0;
@@ -454,6 +473,8 @@ public:
// Whether to fuse FP4 quant into attention kernel.
bool mFuseFp4Quant = false;

kernels::SparseAttentionParams mRuntimeSparseAttentionParams;

// This is implementation details which we want to save when serializing, but not expose as
// a plugin field or a constructor parameter
int32_t mNbMultiBlockSemaphores = 0;
@@ -473,10 +494,11 @@ public:
mPosShiftEnabled, mPagedContextFMHA, mFP8ContextFMHA, mFP8AttenOutput, mFP8ContextMLA, mFP8GenerationMLA,
mChunkPrefillBufferBatchSize, mDenseContextFMHA, mHasFullAttentionMask, mIsSpecDecodingEnabled,
mUseSpecDecoding, mIsSpecDecTree, mSpecDecodingIsGenerationLengthVariable, mSpecDecodingMaxGenerationLength,
mIsMLAEnabled, mIsGenerationMLA, mUseGenFlashMLA, mMLAParams.data(), mCpSize, mCpRank, mCpGroup,
mNumAttnHeads, mNumAttnKVHeads, mNumKVHeadsOrigin, mAttnTpSize, mAttnTpRank, mAttnCpSize, mAttnCpRank,
mUlyssesMQABroadcast, mEnableContextFMHA, mFMHAForceFP32Acc, mMultiBlockMode, mEnableXQA, mUseKVCache,
mSkipAttn, mFuseFp4Quant, mNbMultiBlockSemaphores, mAttentionChunkSize.value_or(-1));
mIsMLAEnabled, mIsGenerationMLA, mUseGenFlashMLA, mUseSparseAttention, mUseTllmGenSparseAttention,
mMLAParams.data(), mCpSize, mCpRank, mCpGroup, mNumAttnHeads, mNumAttnKVHeads, mNumKVHeadsOrigin,
mAttnTpSize, mAttnTpRank, mAttnCpSize, mAttnCpRank, mUlyssesMQABroadcast, mEnableContextFMHA,
mFMHAForceFP32Acc, mMultiBlockMode, mEnableXQA, mUseKVCache, mSkipAttn, mFuseFp4Quant,
mRuntimeSparseAttentionParams.data(), mNbMultiBlockSemaphores, mAttentionChunkSize.value_or(-1));
};

private:
@@ -233,6 +233,8 @@ struct XQALaunchParam
float* bmm2_scale_ptr = nullptr;
int32_t* semaphores = nullptr;
void* scratch = nullptr;
void* sparse_kv_block_offsets = nullptr;
int32_t* sparse_seq_lengths = nullptr;
};

// Setup launch params and ioScratch. ioScratch is for RoPE and output type conversion.
@@ -266,6 +268,9 @@ void buildXQALaunchParams(XQALaunchParam<KVCacheBuffer>& launchParams, void*& in
const size_t cu_kv_seqlens_size = sizeof(int) * (batch_beam_size + 1);
const size_t rotary_inv_freq_size = sizeof(float) * batch_beam_size * params.rotary_embedding_dim / 2;
const size_t tokens_info_size = sizeof(int2) * params.total_num_input_tokens;
const size_t kv_block_offsets_size
= sizeof(int) * batch_beam_size * 2 * params.max_blocks_per_sequence * params.num_kv_heads;
const size_t seq_lengths_size = sizeof(int) * batch_beam_size * params.num_kv_heads;
launchParams.cu_seq_lens = reinterpret_cast<int*>(workspace);
workspace = tensorrt_llm::common::nextWorkspacePtrWithAlignment(workspace, cu_seqlens_size);
launchParams.cu_kv_seq_lens = reinterpret_cast<int*>(workspace);
@@ -281,6 +286,14 @@ void buildXQALaunchParams(XQALaunchParam<KVCacheBuffer>& launchParams, void*& in
workspace = tensorrt_llm::common::nextWorkspacePtrWithAlignment(workspace, bmm1_scale_size);
launchParams.bmm2_scale_ptr = reinterpret_cast<float*>(workspace);
workspace = tensorrt_llm::common::nextWorkspacePtrWithAlignment(workspace, bmm2_scale_size);
// Used for block sparse attention
if (params.use_sparse_attention)
{
launchParams.sparse_kv_block_offsets = reinterpret_cast<void*>(workspace);
workspace = tensorrt_llm::common::nextWorkspacePtrWithAlignment(workspace, kv_block_offsets_size);
launchParams.sparse_seq_lengths = reinterpret_cast<int*>(workspace);
workspace = tensorrt_llm::common::nextWorkspacePtrWithAlignment(workspace, seq_lengths_size);
}
inputScratch = workspace;
if (hasOutputScratch)
{
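The two sparse buffers are carved out of the single pre-sized XQA workspace with the same bump-and-align pattern used for the other sub-buffers. A reduced sketch of that pattern is shown below; the alignUp helper is hypothetical and stands in for tensorrt_llm::common::nextWorkspacePtrWithAlignment, whose exact alignment policy is not reproduced here.

#include <cstddef>
#include <cstdint>

// Hypothetical helper: advance the cursor past 'bytes' and round up to 'alignment'.
static void* alignUp(void* ptr, std::size_t bytes, std::size_t alignment = 256)
{
    auto addr = reinterpret_cast<std::uintptr_t>(ptr) + bytes;
    addr = (addr + alignment - 1) / alignment * alignment;
    return reinterpret_cast<void*>(addr);
}

struct SparseScratch
{
    int* kv_block_offsets;
    int* seq_lengths;
};

// Carve the two sparse-attention views out of one workspace allocation.
SparseScratch carveSparseScratch(void*& workspace, std::size_t kv_block_offsets_bytes, std::size_t seq_lengths_bytes)
{
    SparseScratch out{};
    out.kv_block_offsets = reinterpret_cast<int*>(workspace);
    workspace = alignUp(workspace, kv_block_offsets_bytes);
    out.seq_lengths = reinterpret_cast<int*>(workspace);
    workspace = alignUp(workspace, seq_lengths_bytes);
    return out;
}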
@@ -17,6 +17,7 @@
#include "tensorrt_llm/common/quantization.h"
#include "tensorrt_llm/kernels/gptKernels.h"
#include "tensorrt_llm/kernels/multiHeadAttentionCommon.h"
#include "tensorrt_llm/kernels/sparseAttentionKernels.h"

namespace tensorrt_llm
{
@@ -109,6 +110,10 @@ struct XQAParams
// for cross attention
int32_t const* encoder_input_lengths = nullptr;

// sparse attention parameters
SparseAttentionParams sparse_params;
bool use_sparse_attention = false;

cudaStream_t stream = 0;

std::string toString() const
@@ -179,6 +184,8 @@ struct XQAParams
<< "is_fp8_output :" << (is_fp8_output ? "true" : "false") << std ::endl
<< "fp8_out_scale :" << fp8_out_scale << std ::endl
<< "encoder_input_lengths: " << encoder_input_lengths << std::endl
<< "sparse_params: " << sparse_params.toString() << std::endl
<< "use_sparse_attention :" << (use_sparse_attention ? "true" : "false") << std ::endl
<< "stream :" << stream;

return ss.str();
cpp/tensorrt_llm/kernels/sparseAttentionKernels.cu (new file, 132 lines)
@@ -0,0 +1,132 @@
/*
 * Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "tensorrt_llm/kernels/sparseAttentionKernels.h"
#include <cub/cub.cuh>

namespace tensorrt_llm
{
namespace kernels
{
template <int THREADS_PER_BLOCK>
__global__ void gatherKvPageOffsetsKernel(
int32_t* output_kv_page_offsets, // [num_head_kv, batch_size, 2, max_num_pages_per_seq]
int32_t* output_seq_lengths, // [num_head_kv, batch_size]
int32_t const* kv_page_offsets, // [batch_size, 2, max_num_pages_per_seq]
int32_t const* seq_lengths, // [batch_size]
SparseAttentionParams const sparse_params, int32_t const batch_size, int32_t const tokens_per_page,
int32_t const max_num_pages_per_seq)
{
// Each CUDA block processes one sequence from the batch for one head.
int32_t const head_idx = blockIdx.x;
int32_t const batch_idx = blockIdx.y;
if (batch_idx >= batch_size)
{
return;
}

// Shared memory for reduction.
__shared__ typename cub::BlockReduce<Pair, THREADS_PER_BLOCK>::TempStorage temp_storage;

// Get the range of sparse indices and the sequence length.
int32_t const start_offset = sparse_params.sparse_attn_offsets[batch_idx];
int32_t const end_offset = sparse_params.sparse_attn_offsets[batch_idx + 1];
int32_t const total_pages = sparse_params.sparse_attn_offsets[batch_size];
int32_t const num_sparse_pages = end_offset - start_offset;
int32_t const original_seq_len = seq_lengths[batch_idx];

// Get global sparse index.
int32_t const sparse_idx_global = head_idx * total_pages + start_offset;

// Get the base memory offset. shape: [batch_size, 2, max_num_pages_per_seq]
size_t const src_base_offset = (size_t) batch_idx * 2 * max_num_pages_per_seq;
size_t const dst_base_offset = (size_t) head_idx * batch_size * 2 * max_num_pages_per_seq + src_base_offset;

// Initialize the local max page index and number of valid pages.
int32_t local_max_page_index = -1;
int32_t local_num_valid_pages = 0;

// Perform the gather operation.
for (int32_t i = threadIdx.x; i < num_sparse_pages; i += blockDim.x)
{
// Get the source idx and offset.
int32_t const src_idx = sparse_params.sparse_attn_indices[sparse_idx_global + i];
if (src_idx < 0)
{
continue;
}

// Update the local max page index.
local_max_page_index = max(local_max_page_index, src_idx);
local_num_valid_pages++;

// Get the source and destination offsets.
size_t const src_offset_dim0 = src_base_offset + 0 * max_num_pages_per_seq + src_idx;
size_t const src_offset_dim1 = src_base_offset + 1 * max_num_pages_per_seq + src_idx;
size_t const dst_offset_dim0 = dst_base_offset + 0 * max_num_pages_per_seq + i;
size_t const dst_offset_dim1 = dst_base_offset + 1 * max_num_pages_per_seq + i;

// Perform the gather operation: read from the sparse location and write to the dense location.
output_kv_page_offsets[dst_offset_dim0] = kv_page_offsets[src_offset_dim0];
output_kv_page_offsets[dst_offset_dim1] = kv_page_offsets[src_offset_dim1];
}

// Reduce the local max page indices and number of valid pages.
Pair local_pair = {local_max_page_index, local_num_valid_pages};
Pair result = cub::BlockReduce<Pair, THREADS_PER_BLOCK>(temp_storage).Reduce(local_pair, PairReduceOp());

// Update sequence length for this head and batch.
if (threadIdx.x == 0)
{
int32_t const max_page_index = result.max_val;
int32_t const num_valid_pages = result.sum_val;
int32_t const ori_valid_pages = (original_seq_len + tokens_per_page - 1) / tokens_per_page;
size_t const seq_len_offset = (size_t) head_idx * batch_size + batch_idx;
if (num_valid_pages > 0)
{
int32_t seq_len = original_seq_len - (ori_valid_pages - num_valid_pages) * tokens_per_page;
int32_t seq_len_remain = original_seq_len % tokens_per_page;
if (max_page_index != ori_valid_pages - 1 && seq_len_remain != 0)
{
seq_len += tokens_per_page - seq_len_remain;
}
output_seq_lengths[seq_len_offset] = seq_len;
}
else
{
output_seq_lengths[seq_len_offset] = 0;
}
}
}

// Host-side launcher function
void invokeGatherKvPageOffsets(int32_t* output_kv_page_offsets, int32_t* output_seq_lengths,
int32_t const* kv_page_offsets, int32_t const* seq_lengths, SparseAttentionParams const sparse_params,
int32_t const batch_size, int32_t const num_head_kv, int32_t const tokens_per_page,
int32_t const max_num_pages_per_seq, cudaStream_t stream)
{
// The grid.
dim3 grid(num_head_kv, batch_size, 1);
// The block.
dim3 block(256, 1, 1);
// Shared memory size.
size_t smem_size = sizeof(Pair) * 256;

// Launch the kernel.
gatherKvPageOffsetsKernel<256><<<grid, block, smem_size, stream>>>(output_kv_page_offsets, output_seq_lengths,
kv_page_offsets, seq_lengths, sparse_params, batch_size, tokens_per_page, max_num_pages_per_seq);
}
} // namespace kernels
} // namespace tensorrt_llm
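To make the gather semantics easier to follow, here is a host-side reference of what one (head, sequence) block of the kernel computes: copy the selected page offsets into a dense prefix, and shrink the sequence length by the number of dropped full pages, crediting back the partially filled last page when it is not among the selected pages. This is a hedged CPU sketch for illustration and testing, not code from the commit.

#include <algorithm>
#include <cstdint>
#include <vector>

// Reference for one (head_idx, batch_idx) pair of gatherKvPageOffsetsKernel.
// kv_page_offsets: [2, max_pages] row-major; sparse_indices: selected page ids for this request (-1 = skip).
void gatherPagesReference(std::vector<int32_t>& out_offsets, int32_t& out_seq_len,
    std::vector<int32_t> const& kv_page_offsets, std::vector<int32_t> const& sparse_indices,
    int32_t original_seq_len, int32_t tokens_per_page, int32_t max_pages)
{
    out_offsets.assign(2 * max_pages, 0);
    int32_t max_page_index = -1;
    int32_t num_valid = 0;
    for (std::size_t i = 0; i < sparse_indices.size(); ++i)
    {
        int32_t const src = sparse_indices[i];
        if (src < 0)
            continue;
        max_page_index = std::max(max_page_index, src);
        ++num_valid;
        out_offsets[0 * max_pages + i] = kv_page_offsets[0 * max_pages + src]; // K page
        out_offsets[1 * max_pages + i] = kv_page_offsets[1 * max_pages + src]; // V page
    }
    int32_t const ori_pages = (original_seq_len + tokens_per_page - 1) / tokens_per_page;
    if (num_valid == 0)
    {
        out_seq_len = 0;
        return;
    }
    int32_t seq_len = original_seq_len - (ori_pages - num_valid) * tokens_per_page;
    int32_t const remain = original_seq_len % tokens_per_page;
    // If the partially filled last page was dropped, every retained page is full.
    if (max_page_index != ori_pages - 1 && remain != 0)
        seq_len += tokens_per_page - remain;
    out_seq_len = seq_len;
}

For example, with original_seq_len = 146 and tokens_per_page = 64 (three pages, 18 tokens in the last), keeping pages {1, 2} yields 82 tokens, while keeping pages {0, 1} yields 128 because the partial page was dropped.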
cpp/tensorrt_llm/kernels/sparseAttentionKernels.h (new file, 81 lines)
@@ -0,0 +1,81 @@
/*
 * Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include <cstdint>
#include <cuda_runtime.h>
#include <sstream>
#include <string>
#include <tuple>

namespace tensorrt_llm
{
namespace kernels
{

struct SparseAttentionParams
{
int32_t* sparse_kv_indices{nullptr}; // [num_kv_heads, num_sparse_kv_indices]
int32_t* sparse_attn_indices{nullptr}; // [num_kv_heads, num_sparse_attn_indices]
int32_t* sparse_kv_offsets{nullptr}; // [num_contexts + 1]
int32_t* sparse_attn_offsets{nullptr}; // [num_generations + 1]

std::string toString() const
{
std::stringstream ss;
ss << "sparse_kv_indices: " << this->sparse_kv_indices << std::endl
<< "sparse_attn_indices: " << this->sparse_attn_indices << std::endl
<< "sparse_kv_offsets: " << this->sparse_kv_offsets << std::endl
<< "sparse_attn_offsets: " << this->sparse_attn_offsets << std::endl;
return ss.str();
}

auto data() const
{
return std::make_tuple(sparse_kv_indices, sparse_attn_indices, sparse_kv_offsets, sparse_attn_offsets);
}
};

struct Pair
{
int32_t max_val;
int32_t sum_val;
};

struct PairReduceOp
{
#if defined(__CUDACC__)
inline __device__
#endif
Pair
operator()(Pair const& a, Pair const& b) const
{
Pair result;
result.max_val = a.max_val > b.max_val ? a.max_val : b.max_val;
result.sum_val = a.sum_val + b.sum_val;
return result;
}
};

void invokeGatherKvPageOffsets(int32_t* output_kv_page_offsets, // [num_head_kv, batch_size, 2, max_num_pages_per_seq]
int32_t* output_seq_lengths, // [num_head_kv, batch_size]
int32_t const* kv_page_offsets, // [batch_size, 2, max_num_pages_per_seq]
int32_t const* seq_lengths, // [batch_size]
SparseAttentionParams const sparse_params, int32_t const batch_size, int32_t const num_head_kv,
int32_t const tokens_per_page, int32_t const max_num_pages_per_seq, cudaStream_t stream);

} // namespace kernels
} // namespace tensorrt_llm
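The four pointers form two CSR-style pairs: each sparse_*_offsets array gives the per-request range inside the matching sparse_*_indices array, and the indices are replicated once per KV head along the leading dimension. A small, hypothetical example of how a caller could lay out the generation-side pair for two requests and two KV heads (device pointers would be used in real code):

#include <cstdint>
#include <vector>

#include "tensorrt_llm/kernels/sparseAttentionKernels.h"

int main()
{
    using tensorrt_llm::kernels::SparseAttentionParams;

    // Request 0 keeps pages {0, 2}; request 1 keeps pages {1, 3, 4}.
    // Offsets are cumulative counts over requests: [0, 2, 5].
    std::vector<int32_t> offsets = {0, 2, 5};
    // Indices are stored per KV head: head 0 first, then head 1 (total_pages = 5 per head).
    std::vector<int32_t> indices = {
        0, 2, 1, 3, 4, // head 0
        0, 2, 1, 3, 4, // head 1 (each head may select different pages in general)
    };

    SparseAttentionParams params;
    params.sparse_attn_indices = indices.data(); // device pointers in real use
    params.sparse_attn_offsets = offsets.data();
    // params.sparse_kv_indices / params.sparse_kv_offsets follow the same layout for the context-phase KV write.
    (void) params;
    return 0;
}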
@@ -207,8 +207,6 @@ public:

// Prepare the kernel parameters.
auto kernelParams = KernelParams::setKernelParams(params, kernelMeta, maxNumCtasQ, maxNumCtasKv);
// TODO: set the block sparse attention flag.
kernelParams.mUseBlockSparseAttention = false;

// Prepare kernel parameters list for cuLaunchKernelEx.
void* kernelParamsList[] = {&kernelParams};
@@ -148,6 +148,10 @@ struct QKVPreprocessingParams
int const* cu_seq_lens{nullptr};
// list of cumulative KV sequence lengths, of shape {batch_size + 1}, used by cross attention only.
int const* cu_kv_seq_lens{nullptr};
// list of cumulative length of sparse KV indices, of shape {batch_size + 1}
int const* sparse_kv_offsets{nullptr};
// list of sparse KV indices for writing to KV cache, of shape {num_kv_heads, num_sparse_kv_indices}
int const* sparse_kv_indices{nullptr};
// inverse frequencies (angle raised at various powers) from the RoPE formula
// shape of {batch_size , rotaryEmbeddingDim / 2}
float const* rotary_embedding_inv_freq{nullptr};
@@ -167,6 +171,7 @@ struct QKVPreprocessingParams
int sink_token_len{0};
int token_num{0};
bool remove_padding{true};
bool is_last_chunk{true};
bool cross_attention{false};
int head_num{0};
int kv_head_num{0};
@@ -216,24 +221,48 @@ struct QKVPreprocessingParams
ss << "kv_cache_block_scales_buffer: " << kv_cache_block_scales_buffer.data << std::endl;
ss << "qkv_bias: " << qkv_bias << std::endl;
ss << "tokens_info: " << tokens_info << std::endl;
ss << "seq_lens: "
<< *(runtime::ITensor::wrap(
(void*) seq_lens, nvinfer1::DataType::kINT32, runtime::ITensor::makeShape({batch_size})));
ss << "cache_seq_lens: "
<< *(runtime::ITensor::wrap(
(void*) cache_seq_lens, nvinfer1::DataType::kINT32, runtime::ITensor::makeShape({batch_size})));
ss << "encoder_seq_lens: "
<< *(runtime::ITensor::wrap(
(void*) encoder_seq_lens, nvinfer1::DataType::kINT32, runtime::ITensor::makeShape({batch_size})));
ss << "cu_seq_lens: "
<< *(runtime::ITensor::wrap(
(void*) cu_seq_lens, nvinfer1::DataType::kINT32, runtime::ITensor::makeShape({batch_size})));
ss << "cu_kv_seq_lens: "
<< *(runtime::ITensor::wrap(
(void*) cu_kv_seq_lens, nvinfer1::DataType::kINT32, runtime::ITensor::makeShape({batch_size})));
ss << "rotary_embedding_inv_freq: "
<< *(runtime::ITensor::wrap((void*) rotary_embedding_inv_freq, nvinfer1::DataType::kFLOAT,
runtime::ITensor::makeShape({batch_size, rotary_embedding_dim / 2})));
if (seq_lens && batch_size > 0)
{
ss << "seq_lens: "
<< *(runtime::ITensor::wrap(
(void*) seq_lens, nvinfer1::DataType::kINT32, runtime::ITensor::makeShape({batch_size})));
}
if (cache_seq_lens && batch_size > 0)
{
ss << "cache_seq_lens: "
<< *(runtime::ITensor::wrap(
(void*) cache_seq_lens, nvinfer1::DataType::kINT32, runtime::ITensor::makeShape({batch_size})));
}
if (encoder_seq_lens && batch_size > 0)
{
ss << "encoder_seq_lens: "
<< *(runtime::ITensor::wrap(
(void*) encoder_seq_lens, nvinfer1::DataType::kINT32, runtime::ITensor::makeShape({batch_size})));
}
if (cu_seq_lens && batch_size > 0)
{
ss << "cu_seq_lens: "
<< *(runtime::ITensor::wrap(
(void*) cu_seq_lens, nvinfer1::DataType::kINT32, runtime::ITensor::makeShape({batch_size})));
}
if (cu_kv_seq_lens && batch_size > 0)
{
ss << "cu_kv_seq_lens: "
<< *(runtime::ITensor::wrap(
(void*) cu_kv_seq_lens, nvinfer1::DataType::kINT32, runtime::ITensor::makeShape({batch_size})));
}
if (sparse_kv_offsets)
{
ss << "sparse_kv_offsets: "
<< *(runtime::ITensor::wrap((void*) sparse_kv_offsets, nvinfer1::DataType::kINT32,
runtime::ITensor::makeShape({batch_size + 1})));
}
if (rotary_embedding_inv_freq && batch_size > 0 && rotary_embedding_dim > 0)
{
ss << "rotary_embedding_inv_freq: "
<< *(runtime::ITensor::wrap((void*) rotary_embedding_inv_freq, nvinfer1::DataType::kFLOAT,
runtime::ITensor::makeShape({batch_size, rotary_embedding_dim / 2})));
}
ss << "rotary_coef_cache_buffer: " << rotary_coef_cache_buffer << std::endl;
ss << "qkv_scale_orig_quant: " << qkv_scale_orig_quant << std::endl;
ss << "spec_decoding_position_offsets: " << spec_decoding_position_offsets << std::endl;
@@ -244,6 +273,7 @@ struct QKVPreprocessingParams
ss << "sink_token_len: " << sink_token_len << std::endl;
ss << "token_num: " << token_num << std::endl;
ss << "remove_padding: " << remove_padding << std::endl;
ss << "is_last_chunk: " << is_last_chunk << std::endl;
ss << "cross_attention: " << cross_attention << std::endl;
ss << "head_num: " << head_num << std::endl;
ss << "kv_head_num: " << kv_head_num << std::endl;
@@ -362,23 +392,36 @@ void invokeQKVPreprocessing(QKVPreprocessingParams<T, KVCacheBuffer> params, cud
template <typename T, typename T_cache, typename KVCacheBuffer>
void invokeUpdateCyclicKvCacheAfterFmha(QKVPreprocessingParams<T, KVCacheBuffer> params, cudaStream_t stream);

template <typename T, typename T_cache, typename KVCacheBuffer>
void invokeUpdateSparseKvCacheAfterFmha(QKVPreprocessingParams<T, KVCacheBuffer> params, cudaStream_t stream);

// Debug function to test basic parameter access
template <typename T, typename KVCacheBuffer>
void invokeDebugSparseKvCacheParams(
QKVPreprocessingParams<T, KVCacheBuffer> params, int* debug_output, cudaStream_t stream);

template <typename T, typename KVCacheBuffer>
void invokeKvCachePostprocessing(QKVPreprocessingParams<T, KVCacheBuffer> params, cudaStream_t stream)
{
params.setCommonParameters();
if (params.cache_type == KvCacheDataType::INT8)

// handle sparse KV cache update if needed
if (params.sparse_kv_indices != nullptr && params.sparse_kv_offsets != nullptr && params.is_last_chunk)
{
invokeUpdateCyclicKvCacheAfterFmha<T, int8_t, KVCacheBuffer>(params, stream);
}
if (params.cache_type == KvCacheDataType::INT8)
{
invokeUpdateSparseKvCacheAfterFmha<T, int8_t, KVCacheBuffer>(params, stream);
}
#ifdef ENABLE_FP8
else if (params.cache_type == KvCacheDataType::FP8)
{
invokeUpdateCyclicKvCacheAfterFmha<T, __nv_fp8_e4m3, KVCacheBuffer>(params, stream);
}
else if (params.cache_type == KvCacheDataType::FP8)
{
invokeUpdateSparseKvCacheAfterFmha<T, __nv_fp8_e4m3, KVCacheBuffer>(params, stream);
}
#endif // ENABLE_FP8
else
{
invokeUpdateCyclicKvCacheAfterFmha<T, T, KVCacheBuffer>(params, stream);
else
{
invokeUpdateSparseKvCacheAfterFmha<T, T, KVCacheBuffer>(params, stream);
}
}
}
@@ -1709,6 +1709,130 @@ void invokeUpdateCyclicKvCacheAfterFmha(QKVPreprocessingParams<T, KVCacheBuffer>

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T, typename TCache, int BLOCK_SIZE, int Dh, typename KVCacheBuffer>
__global__ __launch_bounds__(BLOCK_SIZE) void updateSparseKvCacheAfterFmha(
QKVPreprocessingParams<T, KVCacheBuffer> params)
{
// The number of 16B vectors per head size in the kv cache.
constexpr int VECS_PER_HEAD = Dh * sizeof(TCache) / 16;
static_assert(BLOCK_SIZE % VECS_PER_HEAD == 0, "Kernel block should be able to handle entire heads.");

int const batch_idx = blockIdx.z;
int const kv_head_idx = blockIdx.y;

int const total_num_sparse_kv_tokens = params.sparse_kv_offsets[params.batch_size];

int const sparse_start_idx = params.sparse_kv_offsets[batch_idx];
int const sparse_end_idx = params.sparse_kv_offsets[batch_idx + 1];
int const num_sparse_tokens = sparse_end_idx - sparse_start_idx;

int const tokens_per_block = blockDim.y;
int const vecs_per_block = blockDim.x;

extern __shared__ uint4 smem[];
uint4* k_smem = smem;
uint4* v_smem = k_smem + tokens_per_block * VECS_PER_HEAD;

for (int token_block_offset = 0; token_block_offset < num_sparse_tokens; token_block_offset += tokens_per_block)
{
int const sparse_token_offset = token_block_offset + threadIdx.y;

if (sparse_token_offset < num_sparse_tokens)
{
int const global_sparse_idx = sparse_start_idx + sparse_token_offset;
int const sparse_idx_offset = kv_head_idx * total_num_sparse_kv_tokens + global_sparse_idx;

int const src_token_idx = params.sparse_kv_indices[sparse_idx_offset];

void* src_k_ptr = params.kv_cache_buffer.getKBlockPtr(batch_idx, src_token_idx);
void* src_v_ptr = params.kv_cache_buffer.getVBlockPtr(batch_idx, src_token_idx);
auto const src_k_block_ptr = reinterpret_cast<uint4*>(src_k_ptr);
auto const src_v_block_ptr = reinterpret_cast<uint4*>(src_v_ptr);

for (int head_vec_idx = threadIdx.x; head_vec_idx < VECS_PER_HEAD; head_vec_idx += vecs_per_block)
{
auto const src_k_vec_idx
= params.kv_cache_buffer.getKVLocalIdx(src_token_idx, kv_head_idx, VECS_PER_HEAD, head_vec_idx);
auto const src_v_vec_idx
= params.kv_cache_buffer.getKVLocalIdx(src_token_idx, kv_head_idx, VECS_PER_HEAD, head_vec_idx);

k_smem[threadIdx.y * VECS_PER_HEAD + head_vec_idx] = src_k_block_ptr[src_k_vec_idx];
v_smem[threadIdx.y * VECS_PER_HEAD + head_vec_idx] = src_v_block_ptr[src_v_vec_idx];
}
}
__syncthreads();

if (sparse_token_offset < num_sparse_tokens)
{
int const global_sparse_idx = sparse_start_idx + sparse_token_offset;
int const sparse_idx_offset = kv_head_idx * total_num_sparse_kv_tokens + global_sparse_idx;

int const src_token_idx = params.sparse_kv_indices[sparse_idx_offset];
int const dst_token_idx = sparse_token_offset;

if (src_token_idx != dst_token_idx)
{
void* dst_k_ptr = params.kv_cache_buffer.getKBlockPtr(batch_idx, dst_token_idx);
void* dst_v_ptr = params.kv_cache_buffer.getVBlockPtr(batch_idx, dst_token_idx);
auto const dst_k_block_ptr = reinterpret_cast<uint4*>(dst_k_ptr);
auto const dst_v_block_ptr = reinterpret_cast<uint4*>(dst_v_ptr);

for (int head_vec_idx = threadIdx.x; head_vec_idx < VECS_PER_HEAD; head_vec_idx += vecs_per_block)
{
auto const dst_k_vec_idx
= params.kv_cache_buffer.getKVLocalIdx(dst_token_idx, kv_head_idx, VECS_PER_HEAD, head_vec_idx);
auto const dst_v_vec_idx
= params.kv_cache_buffer.getKVLocalIdx(dst_token_idx, kv_head_idx, VECS_PER_HEAD, head_vec_idx);
dst_k_block_ptr[dst_k_vec_idx] = k_smem[threadIdx.y * VECS_PER_HEAD + head_vec_idx];
dst_v_block_ptr[dst_v_vec_idx] = v_smem[threadIdx.y * VECS_PER_HEAD + head_vec_idx];
}
}
}
__syncthreads();
}
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <int Dh, typename T, typename TCache, typename KVCacheBuffer>
void kernelSparseDispatchHeadSize(QKVPreprocessingParams<T, KVCacheBuffer> params, cudaStream_t stream)
{
constexpr int VECS_PER_HEAD = Dh * sizeof(TCache) / 16;
constexpr int BLOCK_SIZE = 1024;
dim3 block(32, 32); // x: head vectors, y: tokens

int smem_size = 2 * block.y * VECS_PER_HEAD * sizeof(uint4);

// grid.x is always 1 to avoid data races
dim3 grid(1, params.kv_head_num, params.batch_size);

updateSparseKvCacheAfterFmha<T, TCache, BLOCK_SIZE, Dh, KVCacheBuffer><<<grid, block, smem_size, stream>>>(params);
}

template <typename T, typename TCache, typename KVCacheBuffer>
void invokeUpdateSparseKvCacheAfterFmha(QKVPreprocessingParams<T, KVCacheBuffer> params, cudaStream_t stream)
{
if (params.sparse_kv_indices == nullptr)
{
return;
}

switch (params.size_per_head)
{
case 16: kernelSparseDispatchHeadSize<16, T, TCache, KVCacheBuffer>(params, stream); break;
case 32: kernelSparseDispatchHeadSize<32, T, TCache, KVCacheBuffer>(params, stream); break;
case 64: kernelSparseDispatchHeadSize<64, T, TCache, KVCacheBuffer>(params, stream); break;
case 128: kernelSparseDispatchHeadSize<128, T, TCache, KVCacheBuffer>(params, stream); break;
case 256: kernelSparseDispatchHeadSize<256, T, TCache, KVCacheBuffer>(params, stream); break;
default:
TLLM_CHECK_WITH_INFO(
false, "updateSparseKvCacheAfterFmha kernel doesn't support head size = %d", params.size_per_head);
break;
}
}

////////////////////////////////////////////////////////////////////////////////////////////////////

#define INSTANTIATE_ATTENTION_INPUT_PROCESSING(T, TCache, KVCacheBuffer) \
template void invokeApplyBiasRopeUpdateKVCacheDispatch<T, TCache, KVCacheBuffer>( \
QKVPreprocessingParams<T, KVCacheBuffer> params, cudaStream_t stream);
@@ -1717,9 +1841,10 @@ void invokeUpdateCyclicKvCacheAfterFmha(QKVPreprocessingParams<T, KVCacheBuffer>
template void invokeApplyBiasRopeUpdateKVCacheDispatch<T, TCache, KVCacheBuffer>( \
QKVPreprocessingParams<T, KVCacheBuffer> params, cudaStream_t stream); \
template void invokeUpdateCyclicKvCacheAfterFmha<T, TCache, KVCacheBuffer>( \
QKVPreprocessingParams<T, KVCacheBuffer> params, cudaStream_t stream);

////////////////////////////////////////////////////////////////////////////////////////////////////
QKVPreprocessingParams<T, KVCacheBuffer> params, cudaStream_t stream); \
template void invokeUpdateSparseKvCacheAfterFmha<T, TCache, KVCacheBuffer>( \
QKVPreprocessingParams<T, KVCacheBuffer> params, cudaStream_t stream); \
////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace kernels
} // namespace tensorrt_llm
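As a sanity check on the launch configuration above: each (token, head) copy moves VECS_PER_HEAD 16-byte vectors, so the shared-memory staging footprint is 2 * blockDim.y * VECS_PER_HEAD * sizeof(uint4) for K plus V. A small constexpr sketch of the numbers for a few typical configurations (illustrative only, not part of the commit):

#include <cstdio>

// 16-byte vectors needed to cover one head of size Dh stored with cacheElemBytes per element.
constexpr int vecsPerHead(int Dh, int cacheElemBytes)
{
    return Dh * cacheElemBytes / 16;
}

constexpr int smemBytes(int tokensPerBlock, int Dh, int cacheElemBytes)
{
    return 2 * tokensPerBlock * vecsPerHead(Dh, cacheElemBytes) * 16; // K + V staging, uint4 units
}

int main()
{
    // Dh = 128, FP16 cache (2 B/elem): 16 vectors per head, 16 KiB of shared memory for 32 tokens.
    static_assert(vecsPerHead(128, 2) == 16, "");
    static_assert(smemBytes(32, 128, 2) == 16 * 1024, "");
    // Dh = 128, FP8/INT8 cache (1 B/elem): 8 vectors per head, 8 KiB for 32 tokens.
    static_assert(vecsPerHead(128, 1) == 8, "");
    static_assert(smemBytes(32, 128, 1) == 8 * 1024, "");
    std::printf("ok\n");
    return 0;
}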
@@ -17,6 +17,7 @@
#include "xqaDispatcher.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplCommon.h"
#include "tensorrt_llm/kernels/sparseAttentionKernels.h"
#include "tensorrt_llm/kernels/unfusedAttentionKernels.h"
#include <cstdint>

@@ -404,9 +405,15 @@ void XqaDispatcher::runImpl(
// Otherwise, always enable the persistent scheduler for better performance.
tllmRunnerParams.mTileScheduler = params.multi_block_mode ? TileScheduler::Static : TileScheduler::Persistent;

// The sequence lengths for K/V.
tllmRunnerParams.seqLensKvPtr = params.cross_attention ? params.encoder_input_lengths : params.sequence_lengths;

// Q buffer.
tllmRunnerParams.qPtr = xqa_q_input_ptr;
// KV buffer

// Use block sparse attention.
tllmRunnerParams.mUseBlockSparseAttention = false;

if constexpr (std::is_same_v<KVCacheBuffer, KVBlockArray>)
{
// Paged KV
@@ -416,9 +423,24 @@ void XqaDispatcher::runImpl(
tllmRunnerParams.kvPageIdxPtr = reinterpret_cast<KVCacheIndex::UnderlyingType const*>(kv_cache_buffer.data);
tllmRunnerParams.mMaxNumPagesPerSeqKv = kv_cache_buffer.mMaxBlocksPerSeq;
tllmRunnerParams.mNumTokensPerPage = kv_cache_buffer.mTokensPerBlock;

// Gather kv page offsets for sparse attention.
if (params.use_sparse_attention)
{
invokeGatherKvPageOffsets(reinterpret_cast<int32_t*>(launchParams.sparse_kv_block_offsets),
launchParams.sparse_seq_lengths, reinterpret_cast<int32_t const*>(kv_cache_buffer.data),
params.sequence_lengths, params.sparse_params, batch_beam_size, num_kv_heads,
kv_cache_buffer.mTokensPerBlock, kv_cache_buffer.mMaxBlocksPerSeq, params.stream);
sync_check_cuda_error(params.stream);
tllmRunnerParams.seqLensKvPtr = launchParams.sparse_seq_lengths;
tllmRunnerParams.kvPageIdxPtr
= reinterpret_cast<KVCacheIndex::UnderlyingType const*>(launchParams.sparse_kv_block_offsets);
tllmRunnerParams.mUseBlockSparseAttention = true;
}
}
else
{
TLLM_CHECK_WITH_INFO(!params.use_sparse_attention, "Sparse attention is not supported for KVLinearBuffer.");
static_assert(std::is_same_v<KVCacheBuffer, KVLinearBuffer>);
// Contiguous KV
tllmRunnerParams.mQkvLayout = QkvLayout::ContiguousKv;
@@ -437,8 +459,6 @@ void XqaDispatcher::runImpl(
tllmRunnerParams.scaleSoftmaxLog2Ptr
= reinterpret_cast<float const*>(launchParams.bmm1_scale_ptr + kIdxScaleSoftmaxLog2Ptr);
tllmRunnerParams.oSfScalePtr = params.fp4_out_sf_scale;
// The sequence lengths for K/V.
tllmRunnerParams.seqLensKvPtr = params.cross_attention ? params.encoder_input_lengths : params.sequence_lengths;

tllmRunnerParams.oPtr = params.output;
tllmRunnerParams.oSfPtr = params.output_sf;
@@ -55,7 +55,7 @@ void initBindings(nb::module_& m)
nb::arg("mrope_rotary_cos_sin") = std::nullopt, nb::arg("mrope_position_deltas") = std::nullopt,
nb::arg("mla_tensor_params"), nb::arg("attention_chunk_size") = std::nullopt,
nb::arg("softmax_stats_tensor") = std::nullopt, nb::arg("spec_decoding_bool_params"),
nb::arg("spec_decoding_tensor_params"), "Multi-head attention operation",
nb::arg("spec_decoding_tensor_params"), nb::arg("sparse_attention_params"), "Multi-head attention operation",
nb::call_guard<nb::gil_scoped_release>());
}
} // namespace tensorrt_llm::nanobind::thop
@@ -572,11 +572,14 @@ size_t GPTAttentionPlugin::getWorkspaceSize(nvinfer1::PluginTensorDesc const* in
= isCrossAttention() ? cross_kv_length : (useKVCache() ? inputs[getIdx(IdxEntry::CACHE_INDIR)].dims.d[2] : 0);
int const max_num_tokens
= mRemovePadding ? inputs[getIdx(IdxEntry::QKV_TENSOR)].dims.d[0] : max_num_seq * max_context_length;
int const max_blocks_per_sequence
= (useKVCache() && mPagedKVCache) ? inputs[getIdx(IdxEntry::KV_CACHE_BLOCK_OFFSETS)].dims.d[3] : 0;

size_t const context_workspace_size
= getWorkspaceSizeForContext(type, max_num_seq, max_context_length, cross_kv_length, max_num_tokens);

size_t const generation_workspace_size
= getWorkspaceSizeForGeneration(type, max_num_seq, max_kv_cache_length, max_num_tokens);
size_t const generation_workspace_size = getWorkspaceSizeForGeneration(
type, max_num_seq, max_kv_cache_length, max_num_tokens, max_blocks_per_sequence);

size_t attention_input_workspace_size = 0;
@@ -55,7 +55,7 @@ void initBindings(pybind11::module_& m)
py::arg("mrope_rotary_cos_sin") = std::nullopt, py::arg("mrope_position_deltas") = std::nullopt,
py::arg("mla_tensor_params"), py::arg("attention_chunk_size") = std::nullopt,
py::arg("softmax_stats_tensor") = std::nullopt, py::arg("spec_decoding_bool_params"),
py::arg("spec_decoding_tensor_params"), "Multi-head attention operation",
py::arg("spec_decoding_tensor_params"), py::arg("sparse_attention_params"), "Multi-head attention operation",
py::call_guard<py::gil_scoped_release>());
}
} // namespace tensorrt_llm::pybind::thop
@@ -19,6 +19,7 @@
#include "tensorrt_llm/common/dataType.h"
#include "tensorrt_llm/kernels/gptKernels.h"
#include "tensorrt_llm/kernels/mlaKernels.h"
#include "tensorrt_llm/kernels/sparseAttentionKernels.h"
#include "tensorrt_llm/runtime/torchUtils.h"
#include "tensorrt_llm/runtime/utils/debugUtils.h"
#include "tensorrt_llm/thop/attentionOp.h"
@@ -38,6 +39,7 @@ namespace trtllm::attention
{
using tensorrt_llm::kernels::KVBlockArray;
using tensorrt_llm::kernels::MlaParams;
using tensorrt_llm::kernels::SparseAttentionParams;

enum class AttentionInputType : int8_t
{
@@ -62,7 +64,7 @@ public:
virtual ~RunnerBase() = default;
virtual void prepare(AttentionOp& op) const = 0;
virtual int64_t getWorkspaceSize(AttentionOp const& op, int const num_tokens, int const max_attention_window_size,
int const num_gen_tokens) const
int const num_gen_tokens, int const max_blocks_per_sequence) const
= 0;
// typically, we use single qkv input, but for context MLA, we use separate qkv inputs
virtual void run(AttentionOp& op, bool const is_context, int32_t const seq_offset, int32_t const num_seqs,
@@ -82,7 +84,9 @@ public:
std::vector<std::optional<torch::Tensor>> mla_tensor_params,
torch::optional<torch::Tensor> softmax_stats_tensor,
c10::ArrayRef<std::optional<torch::Tensor>> spec_decoding_tensor_params,
torch::optional<torch::Tensor> attention_sinks) const
torch::optional<torch::Tensor> attention_sinks, torch::optional<torch::Tensor> sparse_kv_indices,
torch::optional<torch::Tensor> sparse_kv_offsets, torch::optional<torch::Tensor> sparse_attn_indices,
torch::optional<torch::Tensor> sparse_attn_offsets) const
= 0;
};

@@ -110,12 +114,12 @@ public:
}

int64_t getWorkspaceSize(AttentionOp const& op, int const num_tokens, int const max_attention_window_size,
int const num_gen_tokens) const override
int const num_gen_tokens, int const max_blocks_per_sequence) const override
{
size_t const context_workspace_size
= op.getWorkspaceSizeForContext(op.mType, max_num_requests, op.mMaxContextLength, 0, num_tokens);
size_t const generation_workspace_size
= op.getWorkspaceSizeForGeneration(op.mType, max_num_requests, max_attention_window_size, num_gen_tokens);
size_t const generation_workspace_size = op.getWorkspaceSizeForGeneration(
op.mType, max_num_requests, max_attention_window_size, num_gen_tokens, max_blocks_per_sequence);

return std::max(context_workspace_size, generation_workspace_size);
}
@@ -137,7 +141,9 @@ public:
std::vector<std::optional<torch::Tensor>> mla_tensor_params,
torch::optional<torch::Tensor> softmax_stats_tensor,
c10::ArrayRef<std::optional<torch::Tensor>> spec_decoding_tensor_params,
torch::optional<torch::Tensor> attention_sinks) const override
torch::optional<torch::Tensor> attention_sinks, torch::optional<torch::Tensor> sparse_kv_indices,
torch::optional<torch::Tensor> sparse_kv_offsets, torch::optional<torch::Tensor> sparse_attn_indices,
torch::optional<torch::Tensor> sparse_attn_offsets) const override
{
auto stream = at::cuda::getCurrentCUDAStream(qkv_or_q.get_device());
T* attention_input = static_cast<T*>(qkv_or_q.slice(0, token_offset).data_ptr());
@@ -221,6 +227,22 @@ public:
}
}

// Prepare sparse attention parameters
if (is_context)
{
op.mRuntimeSparseAttentionParams.sparse_kv_indices
= sparse_kv_indices.has_value() ? sparse_kv_indices.value().data_ptr<int32_t>() : nullptr;
op.mRuntimeSparseAttentionParams.sparse_kv_offsets
= sparse_kv_offsets.has_value() ? sparse_kv_offsets.value().data_ptr<int32_t>() : nullptr;
}
else
{
op.mRuntimeSparseAttentionParams.sparse_attn_indices
= sparse_attn_indices.has_value() ? sparse_attn_indices.value().data_ptr<int32_t>() : nullptr;
op.mRuntimeSparseAttentionParams.sparse_attn_offsets
= sparse_attn_offsets.has_value() ? sparse_attn_offsets.value().data_ptr<int32_t>() : nullptr;
}

int const* context_lengths_ptr = context_lengths.slice(0, seq_offset).data_ptr<int>();
int const* sequence_lengths_ptr = sequence_length.slice(0, seq_offset).data_ptr<int>();
// Note we still need context length during generation for MMHA optimization.
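The four optional tensors arrive packed in a single sparse_attention_params vector, which the attention() entry point below unpacks into sparse_kv_indices, sparse_kv_offsets, sparse_attn_indices, and sparse_attn_offsets. A hypothetical caller-side sketch of how that argument could be assembled (the helper names here are illustrative; tensor contents follow the CSR layout documented in sparseAttentionKernels.h):

#include <optional>
#include <vector>

#include <torch/torch.h>

// Order expected by the op: {sparse_kv_indices, sparse_kv_offsets, sparse_attn_indices, sparse_attn_offsets}.
std::vector<std::optional<torch::Tensor>> makeSparseAttentionParams(
    std::optional<torch::Tensor> sparse_kv_indices, std::optional<torch::Tensor> sparse_kv_offsets,
    std::optional<torch::Tensor> sparse_attn_indices, std::optional<torch::Tensor> sparse_attn_offsets)
{
    return {std::move(sparse_kv_indices), std::move(sparse_kv_offsets), std::move(sparse_attn_indices),
        std::move(sparse_attn_offsets)};
}

// Dense (non-sparse) layers simply pass four empty optionals, keeping the binding signature uniform.
std::vector<std::optional<torch::Tensor>> denseAttentionParams()
{
    return {std::nullopt, std::nullopt, std::nullopt, std::nullopt};
}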
@@ -518,8 +540,16 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
std::optional<torch::Tensor> mrope_rotary_cos_sin, std::optional<torch::Tensor> mrope_position_deltas,
std::vector<std::optional<torch::Tensor>> mla_tensor_params, std::optional<int64_t> attention_chunk_size,
std::optional<torch::Tensor> softmax_stats_tensor, std::vector<bool> spec_decoding_bool_params,
std::vector<std::optional<torch::Tensor>> spec_decoding_tensor_params)
std::vector<std::optional<torch::Tensor>> spec_decoding_tensor_params,
std::vector<std::optional<torch::Tensor>> sparse_attention_params)
{
// Decompress sparse attention parameters
TORCH_CHECK(sparse_attention_params.size() == 4, "Expected 4 sparse attention parameters");
torch::optional<torch::Tensor> sparse_kv_indices = sparse_attention_params[0];
torch::optional<torch::Tensor> sparse_kv_offsets = sparse_attention_params[1];
torch::optional<torch::Tensor> sparse_attn_indices = sparse_attention_params[2];
torch::optional<torch::Tensor> sparse_attn_offsets = sparse_attention_params[3];

TLLM_LOG_TRACE("Attention op starts at layer %d", layer_idx);
// Use these tensors to infer if the attention is using KV cache
bool const use_kv_cache = kv_cache_block_offsets.has_value() && host_kv_cache_block_offsets.has_value()
@@ -633,6 +663,18 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
op->mUseSpecDecoding = spec_decoding_bool_params[1]; // use_spec_decoding
op->mIsSpecDecTree = spec_decoding_bool_params[2]; // is_spec_dec_tree

op->mUseSparseAttention = false;
op->mUseTllmGenSparseAttention = false;
if ((sparse_kv_indices.has_value() && sparse_kv_indices.value().numel() > 0)
|| (sparse_attn_indices.has_value() && sparse_attn_indices.value().numel() > 0))
{
op->mUseSparseAttention = true;
if (sparse_attn_indices.has_value() && sparse_attn_indices.value().numel() > 0)
{
op->mUseTllmGenSparseAttention = true;
}
}

if (is_mla_enable)
{
// MLA does not support NVFP4 output yet.
@@ -718,7 +760,10 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to

int32_t const max_attention_window_size
= beam_width == 1 ? attention_window_size : cache_indirection.value().size(2);
int64_t const workspace_size = runner->getWorkspaceSize(*op, num_tokens, max_attention_window_size, num_gen_tokens);
int32_t const max_blocks_per_sequence
= use_kv_cache && kv_cache_block_offsets.has_value() ? kv_cache_block_offsets.value().size(-1) : 0;
int64_t const workspace_size
= runner->getWorkspaceSize(*op, num_tokens, max_attention_window_size, num_gen_tokens, max_blocks_per_sequence);
TLLM_LOG_TRACE("Expected workspace size is %ld bytes", workspace_size);

if (workspace_size >= (16l << 30))
@@ -760,7 +805,7 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
host_kv_cache_pool_mapping, cache_indirection, kv_scale_orig_quant, kv_scale_quant_orig, out_scale,
rotary_inv_freq, rotary_cos_sin, latent_cache, q_pe, block_ids_per_seq, mrope_rotary_cos_sin,
mrope_position_deltas, mla_tensor_params, softmax_stats_tensor, spec_decoding_tensor_params,
attention_sinks);
attention_sinks, sparse_kv_indices, sparse_kv_offsets, sparse_attn_indices, sparse_attn_offsets);
}

if ((num_generations > 0) && (attn_input_type != AttentionInputType::ContextOnly))
@@ -777,7 +822,7 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
host_kv_cache_pool_mapping, cache_indirection, kv_scale_orig_quant, kv_scale_quant_orig, out_scale,
rotary_inv_freq, rotary_cos_sin, latent_cache, q_pe, block_ids_per_seq, mrope_rotary_cos_sin,
mrope_position_deltas, mla_tensor_params, softmax_stats_tensor, spec_decoding_tensor_params,
attention_sinks);
attention_sinks, sparse_kv_indices, sparse_kv_offsets, sparse_attn_indices, sparse_attn_offsets);
}

TLLM_LOG_TRACE("Attention op stops at layer %d", layer_idx);
@@ -57,9 +57,10 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
std::optional<int64_t> chunked_prefill_buffer_batch_size, std::optional<int64_t> q_lora_rank,
std::optional<int64_t> kv_lora_rank, std::optional<int64_t> qk_nope_head_dim,
std::optional<int64_t> qk_rope_head_dim, std::optional<int64_t> v_head_dim,
torch::optional<torch::Tensor> mrope_rotary_cos_sin, torch::optional<torch::Tensor> mrope_position_deltas,
std::optional<torch::Tensor> mrope_rotary_cos_sin, std::optional<torch::Tensor> mrope_position_deltas,
std::vector<std::optional<torch::Tensor>> mla_tensor_params, std::optional<int64_t> attention_chunk_size,
std::optional<torch::Tensor> softmax_stats_tensor, std::vector<bool> spec_decoding_bool_params,
std::vector<std::optional<torch::Tensor>> spec_decoding_tensor_params);
std::vector<std::optional<torch::Tensor>> spec_decoding_tensor_params,
std::vector<std::optional<torch::Tensor>> sparse_attention_params);

} // namespace torch_ext
@@ -40,6 +40,7 @@ add_gtest(smoothQuantKernelTest smoothQuant/smoothQuantKernelTest.cpp)
add_gtest(stopCriteriaKernelsTest stopCriteriaKernelsTest.cpp)
add_gtest(weightOnlyKernelTest weightOnly/weightOnlyKernelTest.cpp)
add_gtest(mlaPreprocessTest mlaPreprocessTest.cu)
add_gtest(sparseAttentionKernelsTest sparseAttentionKernelsTest.cpp)

add_gtest(cudaCoreGemmKernelTest cudaCoreGemm/cudaCoreGemmKernelTest.cpp)

@@ -90,3 +91,4 @@ add_gtest(routingKernelsTest "${ROUTING_KERNEL_TEST_SRC}")
add_gtest(moeLoadBalanceKernelTest moeLoadBalanceKernelTest.cpp)

add_gtest(eaglePackDataTest eaglePackDataTest.cpp)
add_gtest(sparseKvCacheTest sparseKvCacheTest.cu)
191
cpp/tests/unit_tests/kernels/sparseAttentionKernelsTest.cpp
Normal file
191
cpp/tests/unit_tests/kernels/sparseAttentionKernelsTest.cpp
Normal file
@ -0,0 +1,191 @@
|
||||
#include <gtest/gtest.h>

#include "tensorrt_llm/kernels/sparseAttentionKernels.h"
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/cudaStream.h"

#include <memory>
#include <vector>

using namespace tensorrt_llm::kernels;
using namespace tensorrt_llm::runtime;

namespace
{
class sparseAttentionKernelsTest : public ::testing::Test
{
public:
    void SetUp() override
    {
        mStream = std::make_shared<CudaStream>();
        mBufferManager = std::make_shared<BufferManager>(mStream);
    }

    void TearDown() override {}

protected:
    std::shared_ptr<CudaStream> mStream;
    std::shared_ptr<BufferManager> mBufferManager;
};

TEST_F(sparseAttentionKernelsTest, GatherKvPageOffsetsKernelTest)
{
    // Test parameters
    constexpr int max_batch_size = 4;
    constexpr int batch_size = 2;
    constexpr int num_head_kv = 4;
    constexpr int max_num_pages_per_seq = 8;
    constexpr int tokens_per_page = 64;
    constexpr int total_sparse_pages = max_batch_size * max_num_pages_per_seq; // Total sparse pages across all batches

    // Create input buffers
    auto kv_page_offsets
        = mBufferManager->gpu(ITensor::makeShape({batch_size, 2, max_num_pages_per_seq}), nvinfer1::DataType::kINT32);
    auto seq_lengths = mBufferManager->gpu(ITensor::makeShape({batch_size}), nvinfer1::DataType::kINT32);
    auto sparse_indices
        = mBufferManager->gpu(ITensor::makeShape({total_sparse_pages, num_head_kv}), nvinfer1::DataType::kINT32);
    auto sparse_indices_offsets = mBufferManager->gpu(ITensor::makeShape({batch_size + 1}), nvinfer1::DataType::kINT32);

    // Create output buffers
    auto output_kv_page_offsets = mBufferManager->gpu(
        ITensor::makeShape({num_head_kv, batch_size, 2, max_num_pages_per_seq}), nvinfer1::DataType::kINT32);
    auto output_seq_lengths
        = mBufferManager->gpu(ITensor::makeShape({num_head_kv, batch_size}), nvinfer1::DataType::kINT32);

    // Create pinned host buffers for data initialization
    auto kv_page_offsets_host = mBufferManager->pinned(
        ITensor::makeShape({batch_size, 2, max_num_pages_per_seq}), nvinfer1::DataType::kINT32);
    auto seq_lengths_host = mBufferManager->pinned(ITensor::makeShape({batch_size}), nvinfer1::DataType::kINT32);
    auto sparse_indices_host
        = mBufferManager->pinned(ITensor::makeShape({total_sparse_pages, num_head_kv}), nvinfer1::DataType::kINT32);
    auto sparse_indices_offsets_host
        = mBufferManager->pinned(ITensor::makeShape({batch_size + 1}), nvinfer1::DataType::kINT32);

    // Initialize test data
    auto kv_page_offsets_ptr = bufferCast<int32_t>(*kv_page_offsets_host);
    auto seq_lengths_ptr = bufferCast<int>(*seq_lengths_host);
    auto sparse_indices_ptr = bufferCast<int>(*sparse_indices_host);
    auto sparse_indices_offsets_ptr = bufferCast<int>(*sparse_indices_offsets_host);

    // Initialize KV page offsets with test data
    for (int b = 0; b < batch_size; ++b)
    {
        for (int d = 0; d < 2; ++d)
        {
            for (int p = 0; p < max_num_pages_per_seq; ++p)
            {
                int offset = b * 2 * max_num_pages_per_seq + d * max_num_pages_per_seq + p;
                kv_page_offsets_ptr[offset] = 1000 + b * 100 + d * 10 + p;
            }
        }
    }

    // Initialize sequence lengths
    seq_lengths_ptr[0] = 2 * tokens_per_page + 18; // 3 pages for batch 0
    seq_lengths_ptr[1] = 3 * tokens_per_page + 3;  // 4 pages for batch 1
    // Initialize sparse indices with different patterns for different heads.
    // The buffer is allocated as {total_sparse_pages, num_head_kv}, but the indices are written
    // head-major, i.e. with the logical layout [num_head_kv, num_sparse_pages].
    // Each head can have its own sparse pattern.
    int num_sparse_pages = 5;
    int sparse_page_indices[5][4] = {{1, 0, 0, 1}, {2, 1, 1, -1}, {-1, 2, -1, -1}, {0, 1, 2, 3}, {3, 3, 3, -1}};
    for (int page = 0; page < num_sparse_pages; ++page)
    {
        for (int head = 0; head < num_head_kv; ++head)
        {
            int idx = head * num_sparse_pages + page;
            sparse_indices_ptr[idx] = sparse_page_indices[page][head];
        }
    }

    // Initialize sparse indices offsets
    sparse_indices_offsets_ptr[0] = 0; // Start of batch 0
    sparse_indices_offsets_ptr[1] = 3; // Start of batch 1 (3 sparse pages for batch 0)
    sparse_indices_offsets_ptr[2] = 5; // End (2 sparse pages for batch 1)
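    // Note (added for clarity): the offsets form a CSR-style prefix sum over batches, so batch b owns
    // entries [offsets[b], offsets[b+1]) of each head's index list, and -1 marks an unused slot. The
    // verification below addresses the head-major table as sparse_indices_ptr[head * offsets[batch_size] + idx];
    // e.g. head 0 / batch 0 keeps pages {1, 2} because its entries are {1, 2, -1}.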
    // Copy data to GPU
    mBufferManager->copy(*kv_page_offsets_host, *kv_page_offsets);
    mBufferManager->copy(*seq_lengths_host, *seq_lengths);
    mBufferManager->copy(*sparse_indices_host, *sparse_indices);
    mBufferManager->copy(*sparse_indices_offsets_host, *sparse_indices_offsets);

    SparseAttentionParams sparse_params;
    sparse_params.sparse_attn_indices = bufferCast<int32_t>(*sparse_indices);
    sparse_params.sparse_attn_offsets = bufferCast<int32_t>(*sparse_indices_offsets);

    // Launch the kernel
    invokeGatherKvPageOffsets(bufferCast<int32_t>(*output_kv_page_offsets), bufferCast<int32_t>(*output_seq_lengths),
        bufferCast<int32_t>(*kv_page_offsets), bufferCast<int32_t>(*seq_lengths), sparse_params, batch_size,
        num_head_kv, tokens_per_page, max_num_pages_per_seq, mStream->get());

    // Wait for completion
    mStream->synchronize();

    // Copy results back to host for verification
    auto output_kv_page_offsets_host = mBufferManager->pinned(
        ITensor::makeShape({num_head_kv, batch_size, 2, max_num_pages_per_seq}), nvinfer1::DataType::kINT32);
    auto output_seq_lengths_host
        = mBufferManager->pinned(ITensor::makeShape({num_head_kv, batch_size}), nvinfer1::DataType::kINT32);

    mBufferManager->copy(*output_kv_page_offsets, *output_kv_page_offsets_host);
    mBufferManager->copy(*output_seq_lengths, *output_seq_lengths_host);

    // Wait for completion
    mStream->synchronize();

    auto output_kv_offsets_ptr = bufferCast<int32_t>(*output_kv_page_offsets_host);
    auto output_seq_len_ptr = bufferCast<int>(*output_seq_lengths_host);

    // Verify sequence lengths for each head and batch
    int expected_seq_lens[4][2] = {
        {tokens_per_page + 18, tokens_per_page + 3},     // Head 0: batch 0 keeps 2 pages, batch 1 keeps 2 pages
        {2 * tokens_per_page + 18, tokens_per_page + 3}, // Head 1: batch 0 keeps 3 pages, batch 1 keeps 2 pages
        {2 * tokens_per_page, tokens_per_page + 3},      // Head 2: batch 0 keeps 2 pages, batch 1 keeps 2 pages
        {tokens_per_page, 3}                             // Head 3: batch 0 keeps 1 page, batch 1 keeps 1 page
    };
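    // Worked example (added for clarity): head 0 / batch 0 gathers pages {1, 2}; page 1 is full
    // (tokens_per_page tokens) and page 2 holds the 18-token remainder of the original
    // 2 * tokens_per_page + 18 sequence, so the gathered length is tokens_per_page + 18.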
    for (int h = 0; h < num_head_kv; ++h)
    {
        for (int b = 0; b < batch_size; ++b)
        {
            int seq_len_idx = h * batch_size + b;
            EXPECT_EQ(output_seq_len_ptr[seq_len_idx], expected_seq_lens[h][b])
                << "Sequence length mismatch at head=" << h << ", batch=" << b;
        }
    }

    // Verify gathered KV page offsets
    for (int h = 0; h < num_head_kv; ++h)
    {
        for (int b = 0; b < batch_size; ++b)
        {
            int num_sparse_pages_batch = sparse_indices_offsets_ptr[b + 1] - sparse_indices_offsets_ptr[b];
            for (int d = 0; d < 2; ++d)
            {
                for (int p = 0; p < num_sparse_pages_batch; ++p)
                {
                    // Calculate expected value (from the sparse index)
                    int sparse_idx_global = sparse_indices_offsets_ptr[b] + p;
                    int src_page_idx
                        = sparse_indices_ptr[h * sparse_indices_offsets_ptr[batch_size] + sparse_idx_global];

                    if (src_page_idx == -1)
                    {
                        continue; // Skip invalid indices
                    }

                    // Calculate output offset
                    size_t output_offset = h * batch_size * 2 * max_num_pages_per_seq + b * 2 * max_num_pages_per_seq
                        + d * max_num_pages_per_seq + p;

                    int expected_value = 1000 + b * 100 + d * 10 + src_page_idx;

                    EXPECT_EQ(output_kv_offsets_ptr[output_offset], expected_value)
                        << "Mismatch at head=" << h << ", batch=" << b << ", dim=" << d << ", page=" << p;
                }
            }
        }
    }
}

} // namespace
521 cpp/tests/unit_tests/kernels/sparseKvCacheTest.cu Normal file
@ -0,0 +1,521 @@
/*
 * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <memory>
#include <vector>

#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/kernels/kvCacheUtils.h"
#include "tensorrt_llm/kernels/unfusedAttentionKernels.h"

using namespace tensorrt_llm::kernels;

class SparseKvCacheTest : public ::testing::Test
{
protected:
    void SetUp() override
    {
        mStream = nullptr;
        TLLM_CUDA_CHECK(cudaStreamCreate(&mStream));

        // Test parameters
        mBatchSize = 2;
        mNumKvHeads = 4;
        mHeadSize = 128;
        mMaxSeqLen = 512;
        mTokensPerBlock = 64;
        mMaxBlocksPerSeq = 8;

        // Allocate test data
        setupTestData();
    }

    void TearDown() override
    {
        cleanup();
        if (mStream)
        {
            TLLM_CUDA_CHECK(cudaStreamDestroy(mStream));
        }
    }

    void setupTestData()
    {
        // Allocate device memory for sparse KV indices and offsets
        size_t sparse_indices_size = 20 * mNumKvHeads * sizeof(int); // 20 sparse tokens max
        size_t sparse_offsets_size = (mBatchSize + 1) * sizeof(int);

        TLLM_CUDA_CHECK(cudaMalloc(&mSparseKvIndicesDevice, sparse_indices_size));
        TLLM_CUDA_CHECK(cudaMalloc(&mSparseKvOffsetsDevice, sparse_offsets_size));
        TLLM_CUDA_CHECK(cudaMalloc(&mSeqLensDevice, mBatchSize * sizeof(int)));
        TLLM_CUDA_CHECK(cudaMalloc(&mCacheSeqLensDevice, mBatchSize * sizeof(int)));

        // Create sparse indices in the correct format: [sparse_token_idx][head_idx]
        // Total sparse tokens: 5 (batch 0) + 3 (batch 1) = 8
        std::vector<int> sparseKvIndicesHost;

        // Batch 0: 5 sparse tokens per head
        std::vector<std::vector<int>> batch0_indices = {
            {1, 2, 3, 4, 5},  // head 0
            {3, 4, 5, 6, 8},  // head 1
            {0, 1, 3, 5, 8},  // head 2
            {1, 3, 5, 10, 11} // head 3
        };

        // Batch 1: 3 sparse tokens per head
        std::vector<std::vector<int>> batch1_indices = {
            {1, 4, 7}, // head 0
            {0, 2, 3}, // head 1
            {1, 2, 7}, // head 2
            {1, 3, 7}  // head 3
        };

        // [num_kv_heads, num_sparse_kv_tokens]
        for (int head = 0; head < mNumKvHeads; ++head)
        {
            for (size_t token = 0; token < batch0_indices[head].size(); ++token)
            {
                sparseKvIndicesHost.push_back(batch0_indices[head][token]);
            }
            for (size_t token = 0; token < batch1_indices[head].size(); ++token)
            {
                sparseKvIndicesHost.push_back(batch1_indices[head][token]);
            }
        }

        std::vector<int> sparseKvOffsetsHost = {0, 5, 8}; // Batch 0: 5 tokens, Batch 1: 3 tokens
        std::vector<int> seqLensHost = {12, 8};           // Original sequence lengths
        std::vector<int> cacheSeqLensHost = {12, 8};      // Cache sequence lengths

        TLLM_CUDA_CHECK(cudaMemcpy(mSparseKvIndicesDevice, sparseKvIndicesHost.data(),
            sparseKvIndicesHost.size() * sizeof(int), cudaMemcpyHostToDevice));
        TLLM_CUDA_CHECK(cudaMemcpy(mSparseKvOffsetsDevice, sparseKvOffsetsHost.data(),
            sparseKvOffsetsHost.size() * sizeof(int), cudaMemcpyHostToDevice));
        TLLM_CUDA_CHECK(
            cudaMemcpy(mSeqLensDevice, seqLensHost.data(), seqLensHost.size() * sizeof(int), cudaMemcpyHostToDevice));
        TLLM_CUDA_CHECK(cudaMemcpy(mCacheSeqLensDevice, cacheSeqLensHost.data(), cacheSeqLensHost.size() * sizeof(int),
            cudaMemcpyHostToDevice));
        TLLM_CUDA_CHECK(cudaDeviceSynchronize());
        // Setup KV cache buffer using KVBlockArray
        setupKvCacheBuffer();
    }

    void setupKvCacheBuffer()
    {
        // Calculate memory requirements
        auto const elemSize = sizeof(half);
        auto const sizePerToken = mNumKvHeads * mHeadSize * elemSize;
        auto const bytesPerBlock = mTokensPerBlock * sizePerToken;
        auto const totalBlocks = mBatchSize * mMaxBlocksPerSeq;
        auto const poolSize = totalBlocks * bytesPerBlock * 2; // K and V

        // Allocate primary pool
        TLLM_CUDA_CHECK(cudaMalloc(&mKvCachePool, poolSize));
        TLLM_CUDA_CHECK(cudaMemset(mKvCachePool, 0, poolSize));

        // Allocate block offsets
        size_t offsetsSize = mBatchSize * mMaxBlocksPerSeq * 2 * sizeof(KVCacheIndex);
        TLLM_CUDA_CHECK(cudaMalloc(&mBlockOffsetsDevice, offsetsSize));

        // Initialize block offsets (simple linear mapping for test)
        std::vector<KVCacheIndex> blockOffsetsHost;
        blockOffsetsHost.reserve(mBatchSize * mMaxBlocksPerSeq * 2);

        for (int batch = 0; batch < mBatchSize; ++batch)
        {
            for (int block = 0; block < mMaxBlocksPerSeq; ++block)
            {
                // K cache block offset
                int kBlockIdx = batch * mMaxBlocksPerSeq * 2 + block;
                blockOffsetsHost.emplace_back(kBlockIdx, false);
            }
            for (int block = 0; block < mMaxBlocksPerSeq; ++block)
            {
                // V cache block offset
                int vBlockIdx = batch * mMaxBlocksPerSeq * 2 + mMaxBlocksPerSeq + block;
                blockOffsetsHost.emplace_back(vBlockIdx, false);
            }
        }

        TLLM_CUDA_CHECK(cudaMemcpy(mBlockOffsetsDevice, blockOffsetsHost.data(),
            blockOffsetsHost.size() * sizeof(KVCacheIndex), cudaMemcpyHostToDevice));

        TLLM_CUDA_CHECK(cudaDeviceSynchronize());
        // Create KVBlockArray with correct parameter order:
        // (batchSize, maxBlocksPerSeq, tokensPerBlock, bytesPerToken,
        //  maxAttentionWindow, maxAttentionWindowAllLayer, sinkTokenLen, canUseOneMoreBlock,
        //  primaryPoolPtr, secondaryPoolPtr, data)
        mKvCacheBuffer = KVBlockArray(mBatchSize, mMaxBlocksPerSeq, mTokensPerBlock, sizePerToken, mMaxSeqLen,
            mMaxSeqLen, 0, false, mKvCachePool, nullptr, mBlockOffsetsDevice);
    }

    void cleanup()
    {
        if (mSparseKvIndicesDevice)
            cudaFree(mSparseKvIndicesDevice);
        if (mSparseKvOffsetsDevice)
            cudaFree(mSparseKvOffsetsDevice);
        if (mSeqLensDevice)
            cudaFree(mSeqLensDevice);
        if (mCacheSeqLensDevice)
            cudaFree(mCacheSeqLensDevice);
        if (mKvCachePool)
            cudaFree(mKvCachePool);
        if (mBlockOffsetsDevice)
            cudaFree(mBlockOffsetsDevice);
        if (mQkvInputDevice)
            cudaFree(mQkvInputDevice);
    }

    // Test parameters
    int mBatchSize;
    int mNumKvHeads;
    int mHeadSize;
    int mMaxSeqLen;
    int mTokensPerBlock;
    int mMaxBlocksPerSeq;

    // Device memory
    int* mSparseKvIndicesDevice = nullptr;
    int* mSparseKvOffsetsDevice = nullptr;
    int* mSeqLensDevice = nullptr;
    int* mCacheSeqLensDevice = nullptr;
    void* mKvCachePool = nullptr;
    KVCacheIndex* mBlockOffsetsDevice = nullptr;
    half* mQkvInputDevice = nullptr;

    KVBlockArray mKvCacheBuffer;
    cudaStream_t mStream;

    // Helper functions for verification
    bool verifySparseKvCacheMapping(std::vector<half> const& originalKvCache);
    void performHostSparseMapping(std::vector<half> const& originalKvCache, std::vector<half>& expectedKvCache);
    void extractKvCacheFromGpu(std::vector<half>& kvCacheHost);
    void initializeKvCacheWithPattern();
};
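// Added note: the test below initializes the KV cache pool with a recognizable pattern, snapshots it,
// runs invokeUpdateSparseKvCacheAfterFmha, and then compares the GPU result against the naive
// host-side reference implemented in performHostSparseMapping().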
TEST_F(SparseKvCacheTest, UpdateSparseKvCacheAfterFmha)
{
    // Allocate dummy QKV input (normally this would come from the attention computation)
    size_t qkvInputSize = mBatchSize * mMaxSeqLen * 3 * mNumKvHeads * mHeadSize * sizeof(half);
    TLLM_CUDA_CHECK(cudaMalloc(&mQkvInputDevice, qkvInputSize));

    // Initialize with test pattern
    std::vector<half> qkvInputHost(qkvInputSize / sizeof(half));
    for (size_t i = 0; i < qkvInputHost.size(); ++i)
    {
        qkvInputHost[i] = half(float(i % 1000) / 100.0f); // Simple test pattern
    }
    TLLM_CUDA_CHECK(cudaMemcpy(mQkvInputDevice, qkvInputHost.data(), qkvInputSize, cudaMemcpyHostToDevice));

    // Initialize KV cache with initial data pattern for testing
    initializeKvCacheWithPattern();

    // Extract the original KV cache data before kernel execution for verification
    size_t totalKvElements = mBatchSize * mMaxSeqLen * mNumKvHeads * mHeadSize * 2; // K and V
    std::vector<half> originalKvCache(totalKvElements);
    extractKvCacheFromGpu(originalKvCache);

    // Setup QKVPreprocessingParams
    QKVPreprocessingParams<half, KVBlockArray> params;
    memset(&params, 0, sizeof(params));

    params.qkv_input = mQkvInputDevice;
    params.kv_cache_buffer = mKvCacheBuffer;
    params.sparse_kv_indices = mSparseKvIndicesDevice;
    params.sparse_kv_offsets = mSparseKvOffsetsDevice;
    params.seq_lens = mSeqLensDevice;
    params.cache_seq_lens = mCacheSeqLensDevice;

    params.batch_size = mBatchSize;
    params.head_num = mNumKvHeads; // For Q heads, assuming same as KV heads for this test
    params.kv_head_num = mNumKvHeads;
    params.size_per_head = mHeadSize;
    params.cache_type = KvCacheDataType::BASE;
    params.rotary_embedding_dim = 0; // No rotary embedding for this test

    params.setCommonParameters();

    // Verify sparse indices and offsets on host
    std::vector<int> hostSparseIndices(8 * mNumKvHeads);
    TLLM_CUDA_CHECK(cudaMemcpy(hostSparseIndices.data(), mSparseKvIndicesDevice, hostSparseIndices.size() * sizeof(int),
        cudaMemcpyDeviceToHost));

    std::vector<int> hostSparseOffsets(mBatchSize + 1);
    TLLM_CUDA_CHECK(cudaMemcpy(hostSparseOffsets.data(), mSparseKvOffsetsDevice, hostSparseOffsets.size() * sizeof(int),
        cudaMemcpyDeviceToHost));
    TLLM_CUDA_CHECK(cudaDeviceSynchronize());
    cudaError_t pre_kernel_error = cudaGetLastError();
    if (pre_kernel_error != cudaSuccess)
    {
        printf("Debug: CUDA error before kernel call: %s\n", cudaGetErrorString(pre_kernel_error));
    }
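    // Added note: per the host reference in performHostSparseMapping(), this kernel is expected to gather
    // the tokens selected by sparse_kv_indices/sparse_kv_offsets and write them contiguously from
    // position 0 of each sequence's K and V cache, independently for every KV head.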
    invokeUpdateSparseKvCacheAfterFmha<half, half, KVBlockArray>(params, mStream);

    TLLM_CUDA_CHECK(cudaStreamSynchronize(mStream));
    cudaError_t post_kernel_error = cudaGetLastError();
    if (post_kernel_error != cudaSuccess)
    {
        printf("Debug: CUDA error after kernel call: %s\n", cudaGetErrorString(post_kernel_error));
    }
    else
    {
        printf("Debug: Kernel call completed, no immediate CUDA errors\n");
    }

    // Verification: Compare GPU result with CPU reference implementation
    EXPECT_TRUE(verifySparseKvCacheMapping(originalKvCache));
}

// Implementation of verification functions
bool SparseKvCacheTest::verifySparseKvCacheMapping(std::vector<half> const& originalKvCache)
{
    // Perform host-side sparse mapping to get expected result
    size_t totalKvElements = originalKvCache.size();
    std::vector<half> expectedKvCache{originalKvCache};
    performHostSparseMapping(originalKvCache, expectedKvCache);

    // Extract actual result from GPU after sparse kernel execution
    std::vector<half> actualKvCache(totalKvElements);
    extractKvCacheFromGpu(actualKvCache);

    // Compare results with tolerance - only for valid sparse tokens
    float const tolerance = 1e-5f;
    bool passed = true;
    int errorCount = 0;
    int const maxErrorsToShow = 10;
    size_t totalValidElements = 0;

    // Only compare the sparse tokens that should have been reorganized
    std::vector<int> sparseTokenCounts = {5, 3}; // Batch 0: 5 tokens, Batch 1: 3 tokens

    for (int batch = 0; batch < mBatchSize; ++batch)
    {
        int numSparseTokens = sparseTokenCounts[batch];

        for (int kv = 0; kv < 2; ++kv) // K and V
        {
            for (int token = 0; token < numSparseTokens; ++token)
            {
                for (int head = 0; head < mNumKvHeads; ++head)
                {
                    for (int dim = 0; dim < mHeadSize; ++dim)
                    {
                        // Calculate index in flat array
                        // Layout: [batch][kv][token][head][dim]
                        size_t idx = batch * (2 * mMaxSeqLen * mNumKvHeads * mHeadSize)
                            + kv * (mMaxSeqLen * mNumKvHeads * mHeadSize) + token * (mNumKvHeads * mHeadSize)
                            + head * mHeadSize + dim;

                        if (idx < totalKvElements)
                        {
                            float expected = float(expectedKvCache[idx]);
                            float actual = float(actualKvCache[idx]);
                            float diff = std::abs(expected - actual);

                            if (diff > tolerance)
                            {
                                if (errorCount < maxErrorsToShow)
                                {
                                    printf(
                                        "Mismatch at batch=%d, kv=%d, token=%d, head=%d, dim=%d: expected %.6f, got "
                                        "%.6f, diff %.6f\n",
                                        batch, kv, token, head, dim, expected, actual, diff);
                                }
                                errorCount++;
                                passed = false;
                            }
                            totalValidElements++;
                        }
                    }
                }
            }
        }
    }

    if (errorCount > 0)
    {
        printf("Total errors: %d out of %zu valid sparse token elements\n", errorCount, totalValidElements);
    }
    else
    {
        printf("Verification passed: all %zu valid sparse token elements match within tolerance %.2e\n",
            totalValidElements, tolerance);
    }

    return passed;
}

void SparseKvCacheTest::performHostSparseMapping(
    std::vector<half> const& originalKvCache, std::vector<half>& expectedKvCache)
{
    // Host-side reference implementation of sparse KV cache mapping
    // This is a naive but correct implementation for verification

    // Read sparse indices from GPU memory to get the actual data being used
    std::vector<int> hostSparseIndices(8 * mNumKvHeads);
    TLLM_CUDA_CHECK(cudaMemcpy(hostSparseIndices.data(), mSparseKvIndicesDevice, hostSparseIndices.size() * sizeof(int),
        cudaMemcpyDeviceToHost));

    std::vector<int> hostSparseOffsets(mBatchSize + 1);
    TLLM_CUDA_CHECK(cudaMemcpy(hostSparseOffsets.data(), mSparseKvOffsetsDevice, hostSparseOffsets.size() * sizeof(int),
        cudaMemcpyDeviceToHost));
    TLLM_CUDA_CHECK(cudaDeviceSynchronize());

    // Process each batch
    for (int batch = 0; batch < mBatchSize; ++batch)
    {
        int const sparse_start_idx = hostSparseOffsets[batch];
        int const sparse_end_idx = hostSparseOffsets[batch + 1];
        int const num_sparse_tokens = sparse_end_idx - sparse_start_idx;

        // Process each head
        for (int head = 0; head < mNumKvHeads; ++head)
        {
            // Process both K and V cache
            for (int kv = 0; kv < 2; ++kv) // 0 = K, 1 = V
            {
                // For each sparse token in this batch
                for (int sparseTokenOffset = 0; sparseTokenOffset < num_sparse_tokens; ++sparseTokenOffset)
                {
                    // Calculate index in new format: [num_kv_heads, num_sparse_kv_tokens]
                    int const global_sparse_idx = sparse_start_idx + sparseTokenOffset;
                    int const sparse_idx_offset = head * 8 + global_sparse_idx; // 8 is total sparse tokens
                    int const originalTokenIdx = hostSparseIndices[sparse_idx_offset];
                    int const continuousTokenIdx = sparseTokenOffset;

                    // Copy from original position to continuous position
                    for (int dim = 0; dim < mHeadSize; ++dim)
                    {
                        // Calculate indices in the flat array
                        // Layout: [batch][kv][token][head][dim]
                        size_t srcIdx = batch * (2 * mMaxSeqLen * mNumKvHeads * mHeadSize)
                            + kv * (mMaxSeqLen * mNumKvHeads * mHeadSize) + originalTokenIdx * (mNumKvHeads * mHeadSize)
                            + head * mHeadSize + dim;

                        size_t dstIdx = batch * (2 * mMaxSeqLen * mNumKvHeads * mHeadSize)
                            + kv * (mMaxSeqLen * mNumKvHeads * mHeadSize)
                            + continuousTokenIdx * (mNumKvHeads * mHeadSize) + head * mHeadSize + dim;

                        if (srcIdx < originalKvCache.size() && dstIdx < expectedKvCache.size())
                        {
                            expectedKvCache[dstIdx] = originalKvCache[srcIdx];
                        }
                    }
                }
            }
        }
    }
}

void SparseKvCacheTest::extractKvCacheFromGpu(std::vector<half>& kvCacheHost)
{
    // Extract KV cache data from GPU KVBlockArray structure
    // This is a simplified extraction for testing purposes

    // Calculate total size needed
    size_t totalElements = mBatchSize * mMaxSeqLen * mNumKvHeads * mHeadSize * 2;
    kvCacheHost.resize(totalElements);

    // For testing, we'll use a simplified approach to read back the cache
    // In a real implementation, this would need to handle the block structure properly

    // Calculate pool size
    auto const elemSize = sizeof(half);
    auto const sizePerToken = mNumKvHeads * mHeadSize * elemSize;
    auto const bytesPerBlock = mTokensPerBlock * sizePerToken;
    auto const totalBlocks = mBatchSize * mMaxBlocksPerSeq;
    auto const poolSize = totalBlocks * bytesPerBlock * 2; // K and V

    // Create temporary buffer to read entire pool
    std::vector<half> poolData(poolSize / sizeof(half));
    TLLM_CUDA_CHECK(cudaMemcpy(poolData.data(), mKvCachePool, poolSize, cudaMemcpyDeviceToHost));
    TLLM_CUDA_CHECK(cudaDeviceSynchronize());
    // Reorganize from block structure to linear structure for comparison
    // This is a simplified mapping - in reality you'd need to handle block indexing
    for (int batch = 0; batch < mBatchSize; ++batch)
    {
        for (int token = 0; token < mMaxSeqLen; ++token)
        {
            for (int head = 0; head < mNumKvHeads; ++head)
            {
                for (int dim = 0; dim < mHeadSize; ++dim)
                {
                    for (int kv = 0; kv < 2; ++kv)
                    {
                        // Calculate block coordinates
                        int blockIdx = token / mTokensPerBlock;
                        int tokenInBlock = token % mTokensPerBlock;
                        // Calculate source index in pool (simplified)
                        // The layout of a block in the pool is [mNumKvHeads, mTokensPerBlock, mHeadSize]
                        size_t block_base_pool_idx
                            = (size_t) (batch * mMaxBlocksPerSeq * 2 + kv * mMaxBlocksPerSeq + blockIdx)
                            * mTokensPerBlock * mNumKvHeads * mHeadSize;
                        size_t inner_block_pool_idx
                            = (size_t) head * mTokensPerBlock * mHeadSize + (size_t) tokenInBlock * mHeadSize + dim;

                        size_t poolIdx = block_base_pool_idx + inner_block_pool_idx;

                        // Calculate destination index in linear layout
                        size_t linearIdx = (size_t) batch * (2 * mMaxSeqLen * mNumKvHeads * mHeadSize)
                            + (size_t) kv * (mMaxSeqLen * mNumKvHeads * mHeadSize)
                            + (size_t) token * (mNumKvHeads * mHeadSize) + (size_t) head * mHeadSize + dim;

                        if (poolIdx < poolData.size() && linearIdx < kvCacheHost.size())
                        {
                            kvCacheHost[linearIdx] = poolData[poolIdx];
                        }
                    }
                }
            }
        }
    }
}

void SparseKvCacheTest::initializeKvCacheWithPattern()
{
    // Calculate pool size
    auto const elemSize = sizeof(half);
    auto const sizePerToken = mNumKvHeads * mHeadSize * elemSize;
    auto const bytesPerBlock = mTokensPerBlock * sizePerToken;
    auto const totalBlocks = mBatchSize * mMaxBlocksPerSeq;
    auto const poolSize = totalBlocks * bytesPerBlock * 2; // K and V

    // Create host data with recognizable pattern
    std::vector<half> poolData(poolSize / sizeof(half));
    for (size_t i = 0; i < poolData.size(); ++i)
    {
        poolData[i] = half(float(i) / 1000.0f);
    }

    // Copy to GPU
    TLLM_CUDA_CHECK(cudaMemcpy(mKvCachePool, poolData.data(), poolSize, cudaMemcpyHostToDevice));
    TLLM_CUDA_CHECK(cudaDeviceSynchronize());
}

// Main function for standalone compilation
int main(int argc, char** argv)
{
    ::testing::InitGoogleTest(&argc, argv);
    return RUN_ALL_TESTS();
}
155 examples/llm-api/llm_sparse_attention.py Normal file
@ -0,0 +1,155 @@
### :title Sparse Attention
### :order 5
### :section Customization
"""
This example demonstrates how to use sparse attention with TensorRT-LLM.

Supported sparse attention algorithms:
- RocketKV

Usage:
```bash
python llm_sparse_attention.py --algo ROCKETKV --attention_backend TRTLLM --window_size 32 --kernel_size 63 --prompt_budget 2048
```
"""
import argparse
import json

from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import KvCacheConfig, RocketSparseAttentionConfig


def read_input(input_file):
    results = []
    with open(input_file, 'r') as f:
        for line in f:
            ret = json.loads(line)
            results.append(ret)
    return results


def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model_path',
        type=str,
        default=
        "/home/scratch.trt_llm_data/llm-models/llama-3.1-model/Llama-3.1-8B-Instruct"
    )
    parser.add_argument(
        '--input_file',
        type=str,
        default="tests/unittest/_torch/multi_gpu/test_star_attention_input.jsonl"
    )
    # Build config
    parser.add_argument('--algo',
                        type=str,
                        default='ROCKETKV',
                        choices=['ROCKETKV'])
    parser.add_argument('--attention_backend',
                        type=str,
                        default='TRTLLM',
                        choices=['VANILLA', 'TRTLLM'])
    parser.add_argument('--window_size',
                        type=int,
                        default=32,
                        help="The window size for RocketKV.")
    parser.add_argument('--kernel_size',
                        type=int,
                        default=63,
                        help="The kernel size for RocketKV.")
    parser.add_argument('--prompt_budget',
                        type=int,
                        default=2048,
                        help="The prompt budget for RocketKV.")
    parser.add_argument("--max_seq_len",
                        type=int,
                        default=8192,
                        help="The maximum sequence length.")
    parser.add_argument("--max_batch_size",
                        type=int,
                        default=256,
                        help="The maximum batch size.")
    parser.add_argument("--max_new_tokens",
                        type=int,
                        default=128,
                        help="The maximum new tokens.")
    parser.add_argument(
        "--max_num_tokens",
        type=int,
        default=8192,
        help=
        "The maximum total tokens (context + generation) across all sequences in a batch."
    )
    parser.add_argument('--tensor_parallel_size', type=int, default=1)

    # KV cache
    parser.add_argument('--kv_cache_dtype', type=str, default='auto')
    parser.add_argument("--kv_cache_fraction", type=float, default=None)
    parser.add_argument('--num_samples', type=int, default=10)

    args = parser.parse_args()
    return args


def run_RocketKV(args):
    data = read_input(args.input_file)
    num_samples = args.num_samples if args.num_samples is not None else len(
        data)
    data = data[:num_samples]

    kv_cache_config = KvCacheConfig(
        enable_block_reuse=
        False,  # sparse attention does not support kv cache reuse now
        free_gpu_memory_fraction=args.kv_cache_fraction,
        dtype=args.kv_cache_dtype,
    )
    sparse_attention_config = RocketSparseAttentionConfig(
        window_size=args.window_size,
        kernel_size=args.kernel_size,
        prompt_budget=args.prompt_budget,
    )

    llm = LLM(
        model=args.model_path,
        backend='pytorch',
        kv_cache_config=kv_cache_config,
        attn_backend=args.attention_backend,
        sparse_attention_config=sparse_attention_config,
        max_batch_size=args.max_batch_size,
        max_seq_len=args.max_seq_len,
        max_num_tokens=args.max_num_tokens,
        tensor_parallel_size=args.tensor_parallel_size,
        cuda_graph_config=
        None,  # sparse attention does not support cuda graph now
    )

    prompts = []
    reference = []
    for sample in data:
        prompts.append(
            {'prompt': sample['input_context'] + sample['input_query']})
        reference.append(sample['outputs'])

    sampling_params = SamplingParams(add_special_tokens=False,
                                     max_tokens=args.max_new_tokens,
                                     temperature=0.8,
                                     top_p=0.95)

    outputs = llm.generate(prompts, sampling_params)
    for idx, output in enumerate(outputs):
        print(
            f'Generated text: {output.outputs[0].text!r}, ref: {reference[idx]}'
        )


def main():
    args = parse_arguments()
    if args.algo == 'ROCKETKV':
        run_RocketKV(args)
    else:
        raise ValueError(f"Invalid algorithm: {args.algo}")


if __name__ == "__main__":
    main()
170 examples/longbench/README.md Normal file
@ -0,0 +1,170 @@
# LongBench Evaluation with TensorRT-LLM and Sparse Attention

This directory contains evaluation scripts for both the LongBench v1 and LongBench v2 datasets using the TensorRT-LLM backend.

## Environment Setup

### 1. Clone LongBench Repository

First, clone the LongBench repository, which contains the datasets and evaluation utilities:

```bash
git clone https://github.com/THUDM/LongBench.git
```

### 2. Install Requirements

Install the required dependencies:

```bash
pip install -r requirements.txt
```

### 3. Directory Structure

After cloning, your directory structure should look like:
```text
longbench/
├── eval_longbench_v1.py        # LongBench v1 evaluation script
├── eval_longbench_v2.py        # LongBench v2 evaluation script
├── README.md                   # This file
└── LongBench/                  # Cloned LongBench repository
    ├── LongBench/              # LongBench v1 data and configs
    │   ├── config/
    │   └── ...
    ├── config/                 # LongBench v2 configs
    ├── ...
    └── requirements.txt
```

## Scripts Overview

### 1. eval_longbench_v1.py

This script evaluates models on the **LongBench v1** dataset, which includes multiple specific tasks like narrativeqa, qasper, multifieldqa, etc. Key features:

- **Dataset**: LongBench v1 with task-specific evaluation
- **Tasks**: Support for 20+ different long-context tasks
- **Prompts**: Task-specific prompts from LongBench v1 configuration
- **Metrics**: Task-specific metrics (F1, ROUGE, classification scores, etc.)
- **Output**: Task-level results with comprehensive summary statistics

### 2. eval_longbench_v2.py

This script evaluates models on the **LongBench v2** dataset, which is a standardized multiple-choice format. Key features:

- **Dataset**: LongBench v2 with unified multiple-choice format
- **Format**: All questions are A/B/C/D multiple choice
- **Context Length**: 8K to 2M words (majority under 128K)
- **Difficulty**: Easy/Hard categorization
- **Length**: Short/Medium/Long categorization
- **Domains**: Various domains (single-doc QA, multi-doc QA, code, etc.)
- **CoT Support**: Chain-of-Thought reasoning support
- **Metrics**: Accuracy with breakdowns by difficulty, length, and domain

## Usage Examples

### LongBench v1 Evaluation

#### Basic Usage (Standard Attention)
```bash
python eval_longbench_v1.py \
    --model_path "/path/to/your/model" \
    --longbench_path ./LongBench \
    --output_dir results/v1_vanilla \
    --attention_backend VANILLA \
    --backend pytorch
```
#### Specific Tasks with Sparse Attention (RocketKV)
```bash
python eval_longbench_v1.py \
    --model_path "/path/to/your/model" \
    --longbench_path ./LongBench \
    --dataset narrativeqa qasper \
    --output_dir results/v1_rocket \
    --attention_backend VANILLA \
    --backend pytorch \
    --rocket_sparse
```

### LongBench v2 Evaluation

#### Basic Usage (Standard Attention)
```bash
python eval_longbench_v2.py \
    --model_path "/path/to/your/model" \
    --longbench_path ./LongBench \
    --output_dir results/v2_vanilla
```

#### With Chain-of-Thought Reasoning
```bash
python eval_longbench_v2.py \
    --model_path "/path/to/your/model" \
    --longbench_path ./LongBench \
    --output_dir results/v2_cot \
    --cot
```

#### Filter by Difficulty/Length/Domain
```bash
# Easy questions only
python eval_longbench_v2.py \
    --model_path "/path/to/your/model" \
    --longbench_path ./LongBench \
    --output_dir results/v2_easy \
    --difficulty easy

# Long context only
python eval_longbench_v2.py \
    --model_path "/path/to/your/model" \
    --longbench_path ./LongBench \
    --output_dir results/v2_long \
    --length long

# Specific domain
python eval_longbench_v2.py \
    --model_path "/path/to/your/model" \
    --longbench_path ./LongBench \
    --output_dir results/v2_code \
    --domain "Code"
```

#### Limited Sample Evaluation (for testing)
```bash
python eval_longbench_v2.py \
    --model_path "/path/to/your/model" \
    --longbench_path ./LongBench \
    --output_dir results/v2_test \
    --num_samples 10
```

## Output Structure

### LongBench v1 Output

```text
results/v1_experiment/
├── config.json                    # Experiment configuration
├── overall_summary.json           # Overall experiment summary
├── narrativeqa/
│   ├── narrativeqa_results.jsonl  # Detailed results
│   ├── narrativeqa_summary.json   # Task summary
│   └── pred/
│       └── narrativeqa.jsonl      # Predictions in LongBench format
├── qasper/
│   └── ...
└── ...
```

### LongBench v2 Output

```text
results/v2_experiment/
├── config.json                  # Experiment configuration
├── summary.json                 # Evaluation summary with metrics
├── longbench_v2_results.jsonl   # Detailed results
└── predictions.jsonl            # Predictions in LongBench v2 format
```
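
To sanity-check a finished run, the summary files can be inspected directly. A minimal sketch (the exact field names depend on what the evaluation script writes into `summary.json`):

```python
import json

# Load and pretty-print whatever metrics the evaluation wrote out.
with open("results/v2_experiment/summary.json", "r", encoding="utf-8") as f:
    summary = json.load(f)
print(json.dumps(summary, indent=2, ensure_ascii=False))
```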
797 examples/longbench/eval_longbench_v1.py Normal file
@ -0,0 +1,797 @@
#!/usr/bin/env python3
"""
LongBench v1 evaluation script with TensorRT-LLM and sparse attention.

Usage:
    python eval_longbench_v1.py --dataset narrativeqa --model_path /path/to/model --longbench_path ./LongBench --output_dir results/

    # Run all LongBench tasks
    python eval_longbench_v1.py --model_path /path/to/model --longbench_path ./LongBench --output_dir results/ --run_all_tasks --token_budget 2048 --rocket_sparse
"""

import argparse
import json
import os
import sys
import time
from datetime import datetime
from typing import Any, Dict, List, Tuple

import numpy as np
from datasets import load_dataset
from transformers import AutoTokenizer

# Add tensorrt_llm imports
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import KvCacheConfig, RocketSparseAttentionConfig
from tensorrt_llm.logger import logger

LONGBENCH_DATASETS = ["narrativeqa", "qasper", "multifieldqa_en", "multifieldqa_zh", "hotpotqa", "2wikimqa", "musique", \
    "dureader", "gov_report", "qmsum", "multi_news", "vcsum", "trec", "triviaqa", "samsum", "lsht", \
    "passage_count", "passage_retrieval_en", "passage_retrieval_zh", "lcc", "repobench-p"]

# Task categorization
TASK_DATASETS = {
    'single_doc_qa':
    ["narrativeqa", "qasper", "multifieldqa_en", "multifieldqa_zh"],
    'multi_doc_qa': ["hotpotqa", "2wikimqa", "musique", "dureader"],
    'summarization': ["gov_report", "qmsum", "multi_news", "vcsum"],
    'few_shots': ["trec", "triviaqa", "samsum", "lsht"],
    'synthetic':
    ["passage_count", "passage_retrieval_en", "passage_retrieval_zh"],
    'code': ["lcc", "repobench-p"]
}

# Chat templates mapping
CHAT_TEMPLATES = {
    "llama3.1-8b-instruct": "llama3",
    "llama3-8b-instruct": "llama3",
    "mistral-7b-instruct-v0.2": "mistral",
    "longchat-7b-v1.5-32k": "vicuna"
}
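# Note (added for clarity): determine_chat_template() below matches these keys against the model path
# after stripping '-' and '.', so a path containing "Llama-3.1-8B-Instruct" resolves to the "llama3" template.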


def parse_arguments() -> argparse.Namespace:
    """Parse command line arguments."""
    parser = argparse.ArgumentParser(
        description="LongBench evaluation with TensorRT-LLM and RocketKV")

    # Model and data arguments
    parser.add_argument('--model_path',
                        type=str,
                        required=True,
                        help='Path to model (HF model name or local path)')
    parser.add_argument('--dataset',
                        type=str,
                        nargs='+',
                        choices=LONGBENCH_DATASETS,
                        help='LongBench datasets to evaluate on')
    parser.add_argument('--run_all_tasks',
                        action='store_true',
                        help='Run evaluation on all LongBench tasks')
    parser.add_argument('--longbench_path',
                        type=str,
                        default='./LongBench',
                        help='Path to LongBench directory')

    # Output arguments
    parser.add_argument('--output_dir',
                        type=str,
                        required=True,
                        help='Directory to save results')
    parser.add_argument('--exp_name',
                        type=str,
                        default=None,
                        help='Experiment name (auto-generated if not provided)')

    # Model configuration
    parser.add_argument('--attention_backend',
                        type=str,
                        default='VANILLA',
                        choices=['VANILLA', 'TRTLLM', 'FLASHINFER'],
                        help='Attention backend to use')
    parser.add_argument('--backend',
                        type=str,
                        default='pytorch',
                        choices=['pytorch', 'tensorrt'],
                        help='LLM backend to use')
    parser.add_argument('--chat_template',
                        type=str,
                        default='auto',
                        help='Chat template to use (auto-detect if "auto")')

    # Sequence and batch configuration
    parser.add_argument('--max_seq_len',
                        type=int,
                        default=133120,
                        help='Maximum sequence length')
    parser.add_argument('--max_batch_size',
                        type=int,
                        default=1,
                        help='Maximum batch size')
    parser.add_argument('--max_new_tokens',
                        type=int,
                        default=256,
                        help='Maximum new tokens to generate')
    parser.add_argument(
        '--max_num_tokens',
        type=int,
        default=133120,
        help='Maximum total tokens across all sequences in a batch')
    parser.add_argument('--tensor_parallel_size',
                        type=int,
                        default=1,
                        help='Tensor parallel size')

    # RocketKV configuration
    parser.add_argument('--rocket_sparse',
                        action='store_true',
                        help='Use rocket sparse attention')
    parser.add_argument('--token_budget',
                        type=int,
                        default=2048,
                        help='Token budget for RocketKV (prompt_budget)')
    parser.add_argument('--window_size',
                        type=int,
                        default=32,
                        help='Window size for RocketKV')
    parser.add_argument('--kernel_size',
                        type=int,
                        default=63,
                        help='Kernel size for RocketKV')
    parser.add_argument('--topr',
                        type=int,
                        default=90,
                        help='Top-r for RocketKV')

    # KV cache configuration
    parser.add_argument('--kv_cache_dtype',
                        type=str,
                        default='auto',
                        help='KV cache data type')
    parser.add_argument('--kv_cache_fraction',
                        type=float,
                        default=0.7,
                        help='Fraction of GPU memory for KV cache')

    # Evaluation parameters
    parser.add_argument('--num_samples',
                        type=int,
                        default=None,
                        help='Number of samples to evaluate (None for all)')
    parser.add_argument('--start_idx',
                        type=int,
                        default=0,
                        help='Start index for evaluation')

    # System arguments
    parser.add_argument('--log_level',
                        type=str,
                        default='info',
                        choices=['debug', 'info', 'warning', 'error'],
                        help='Logging level')
    parser.add_argument('--seed', type=int, default=42, help='Random seed')

    args = parser.parse_args()

    # Validation
    if not args.run_all_tasks and not args.dataset:
        parser.error("Must specify either --dataset or --run_all_tasks")

    return args
def setup_longbench_imports(longbench_path: str):
    """Add LongBench to Python path and import required modules."""
    longbench_dir = os.path.join(longbench_path, "LongBench")  # for v1
    if not os.path.exists(longbench_dir):
        raise FileNotFoundError(
            f"LongBench directory not found: {longbench_dir}")

    # Add to path
    if longbench_dir not in sys.path:
        sys.path.insert(0, longbench_dir)


def load_longbench_config(
        longbench_path: str) -> Tuple[Dict[str, str], Dict[str, int]]:
    """Load LongBench configuration files."""
    config_dir = os.path.join(longbench_path, "LongBench", "config")

    # Load dataset2prompt.json
    prompt_file = os.path.join(config_dir, "dataset2prompt.json")
    with open(prompt_file, 'r', encoding='utf-8') as f:
        dataset2prompt = json.load(f)

    # Load dataset2maxlen.json
    maxlen_file = os.path.join(config_dir, "dataset2maxlen.json")
    with open(maxlen_file, 'r', encoding='utf-8') as f:
        dataset2maxlen = json.load(f)

    return dataset2prompt, dataset2maxlen


# LongBench's build_chat function (simplified version)
def build_chat(tokenizer, prompt, chat_template):
    """Build chat prompt following LongBench's approach."""
    if chat_template == "vicuna" or chat_template == "longchat":
        try:
            from fastchat.model import get_conversation_template
            conv = get_conversation_template("vicuna")
            conv.append_message(conv.roles[0], prompt)
            conv.append_message(conv.roles[1], None)
            prompt = conv.get_prompt()
        except ImportError:
            # Fallback if fastchat is not available
            prompt = f"A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n\nUSER: {prompt}\nASSISTANT:"
    elif chat_template == "llama2":
        prompt = f"[INST]{prompt}[/INST]"
    elif chat_template == "llama3":
        prompt = f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
    elif chat_template == "mistral":
        prompt = f"[INST] {prompt} [/INST]"
    # For other templates or "none", return prompt as-is
    return prompt


def determine_chat_template(model_path: str, chat_template: str) -> str:
    """Determine chat template based on model path."""
    if chat_template != 'auto':
        return chat_template

    model_path_lower = model_path.lower()

    for model_key, template in CHAT_TEMPLATES.items():
        if model_key.replace('-', '').replace('.', '') in model_path_lower.replace('-', '').replace('.', ''):
            return template

    # Default fallback
    if 'llama' in model_path_lower:
        return 'llama3'
    elif 'mistral' in model_path_lower:
        return 'mistral'
    else:
        return 'none'  # No special formatting


def post_process(pred: str, chat_template: str, dataset: str) -> str:
    """Post-process prediction following LongBench's approach."""
    pred = pred.split("</s")[0].strip()
    if chat_template == "qwen":
        pred = pred.split("<|im_end|>")[0]
    elif "llama2" in chat_template.lower():
        pred = (pred.split("(Document")[0].split("\n\nQuestion")[0].split(
            "\n\nAnswer")[0].split("[INST]")[0].split("[/INST]")[0].split(
                "(Passage")[0].strip())
    if dataset == "samsum":
        pred = pred.split("\n")[0].strip()

    return pred


def format_prompt_style(sample: Dict[str, Any], instruction: str,
                        chat_template: str, dataset: str, tokenizer) -> str:
    """Format prompt following LongBench's approach."""
    # First format the instruction using the sample data
    prompt = instruction.format(**sample)

    if dataset not in [
            "trec", "triviaqa", "samsum", "lsht", "lcc", "repobench-p"
    ]:
        prompt = build_chat(tokenizer, prompt, chat_template)

    return prompt


def initialize_llm(args: argparse.Namespace) -> Tuple[LLM, AutoTokenizer]:
    """Initialize LLM and tokenizer."""
    logger.info(f"Initializing LLM with model: {args.model_path}")

    try:
        # Configure KV cache
        kv_cache_config = KvCacheConfig(
            enable_block_reuse=False,  # RocketKV doesn't support KV cache reuse
        )

        if args.rocket_sparse:
            # Configure RocketKV sparse attention
            sparse_attention_config = RocketSparseAttentionConfig(
                window_size=args.window_size,
                kernel_size=args.kernel_size,
                prompt_budget=args.token_budget,
                topr=args.topr,
            )
            logger.info(f"Using RocketKV sparse attention")
        else:
            sparse_attention_config = None
            logger.info("Using standard attention")

        # Initialize LLM
        llm = LLM(
            model=args.model_path,
            backend=args.backend,
            kv_cache_config=kv_cache_config,
            max_batch_size=args.max_batch_size,
            attn_backend=args.attention_backend,
            sparse_attention_config=sparse_attention_config,
            tensor_parallel_size=args.tensor_parallel_size,
            max_seq_len=args.max_seq_len,
            max_num_tokens=args.max_num_tokens,
            cuda_graph_config=None,
            torch_compile_config=None,
        )

        # Initialize tokenizer
        tokenizer = AutoTokenizer.from_pretrained(args.model_path)

        logger.info("LLM and tokenizer initialized successfully")

        return llm, tokenizer

    except Exception as e:
        logger.error(f"Failed to initialize LLM: {e}")
        raise e
def evaluate_single_dataset(
        dataset: str, llm: LLM, tokenizer: AutoTokenizer,
        args: argparse.Namespace) -> Tuple[List[Dict], float]:
    """Evaluate a single dataset."""
    setup_longbench_imports(args.longbench_path)

    dataset2prompt, dataset2maxlen = load_longbench_config(args.longbench_path)

    # Load dataset
    logger.info(f"Loading dataset: {dataset}")
    data = [
        data_sample for data_sample in load_dataset(
            'THUDM/LongBench', dataset, split='test', trust_remote_code=True)
    ]

    # Apply data filtering
    if args.num_samples:
        end_idx = min(args.start_idx + args.num_samples, len(data))
        filtered_data = data[args.start_idx:end_idx]
    else:
        filtered_data = data[args.start_idx:]

    logger.info(f"Dataset {dataset}: {len(filtered_data)} samples to evaluate")

    # Determine chat template
    chat_template = determine_chat_template(args.model_path, args.chat_template)
    logger.info(f"Using chat template: {chat_template}")

    # Create sampling parameters
    max_new_tokens = dataset2maxlen[dataset]
    prompt_format = dataset2prompt[dataset]

    # Set up extra end token ids
    extra_end_token_ids = []
    if chat_template == "llama3":
        eot_id = tokenizer.encode("<|eot_id|>", add_special_tokens=False)[0]
        extra_end_token_ids.append(eot_id)
        logger.info(f"Added llama3 end token: {eot_id}")

    if chat_template == "qwen":
        im_end_id = tokenizer.encode("<|im_end|>", add_special_tokens=False)[0]
        extra_end_token_ids.append(im_end_id)
        logger.info(f"Added qwen end token: {im_end_id}")

    if dataset == "samsum":
        newline_id = tokenizer.encode("\n", add_special_tokens=False)[-1]
        extra_end_token_ids.append(newline_id)
        logger.info(f"Added samsum newline token: {newline_id}")

    # Prepare prompts
    prompts = []
    for sample in filtered_data:
        formatted_prompt = format_prompt_style(sample, prompt_format,
                                               chat_template, dataset,
                                               tokenizer)
        prompts.append(formatted_prompt)

    if len(prompts) == 0:
        logger.warning(f"No prompts to evaluate for dataset {dataset}")
        return [], 0.0

    # Run inference
    logger.info(f"Starting inference for {len(prompts)} samples...")
    start_time = time.time()

    sampling_params = SamplingParams(
        max_tokens=max_new_tokens,
        temperature=0.8,
        top_p=0.95,
        stop_token_ids=extra_end_token_ids if extra_end_token_ids else None,
    )

    outputs = llm.generate(prompts, sampling_params)

    inference_time = time.time() - start_time
    logger.info(
        f"Inference completed in {inference_time:.2f} seconds, average time per sample: {inference_time/len(prompts):.3f} seconds"
    )

    # Prepare results
    results = []
    for i, (sample, output) in enumerate(zip(filtered_data, outputs)):
        prediction = output.outputs[0].text.strip()
        processed_prediction = post_process(prediction, chat_template, dataset)

        result = {
            'sample_id': args.start_idx + i,
            'input': sample.get('input', ''),
            'context': sample.get('context', ''),
            'answers': sample.get('answers', []),
            'all_classes': sample.get('all_classes', []),
            'prediction': processed_prediction,
            'raw_prediction': prediction,
            'prompt_length': len(output.prompt_token_ids),
            'output_length': len(output.outputs[0].token_ids),
            'inference_time': getattr(output, 'inference_time', None),
            'length': sample.get('length', 0)
        }
        results.append(result)

    return results, inference_time
def calculate_metrics(
        dataset: str, predictions: List[str], answers_list: List[List[str]],
        all_classes_list: List[List[str]],
        longbench_path: str) -> Tuple[Dict[str, float], List[float]]:
    """Calculate evaluation metrics for a dataset following LongBench's implementation."""

    # Setup LongBench imports
    setup_longbench_imports(longbench_path)

    # Import LongBench metrics
    from metrics import (classification_score, code_sim_score, count_score,
                         qa_f1_score, qa_f1_zh_score, retrieval_score,
                         retrieval_zh_score, rouge_score, rouge_zh_score)

    # Mapping of datasets to their metric functions (from LongBench)
    dataset2metric = {
        "narrativeqa": qa_f1_score,
        "qasper": qa_f1_score,
        "multifieldqa_en": qa_f1_score,
        "multifieldqa_zh": qa_f1_zh_score,
        "hotpotqa": qa_f1_score,
        "2wikimqa": qa_f1_score,
        "musique": qa_f1_score,
        "dureader": rouge_zh_score,
        "gov_report": rouge_score,
        "qmsum": rouge_score,
        "multi_news": rouge_score,
        "vcsum": rouge_zh_score,
        "trec": classification_score,
        "triviaqa": qa_f1_score,
        "samsum": rouge_score,
        "lsht": classification_score,
        "passage_retrieval_en": retrieval_score,
        "passage_count": count_score,
        "passage_retrieval_zh": retrieval_zh_score,
        "lcc": code_sim_score,
        "repobench-p": code_sim_score,
    }

    if dataset not in dataset2metric:
        # Fallback to simple exact match with cleaning
        total_score = 0
        scores = []
        for pred, answers in zip(predictions, answers_list):
            cleaned_pred = pred.lstrip('\n').split('\n')[0].strip()
            score = max([
                1.0 if cleaned_pred.lower() == ans.strip().lower() else 0.0
                for ans in answers
            ])
            scores.append(score)
            total_score += score
        return {
            "exact_match": round(100 * total_score / len(predictions), 2)
        }, scores

    metric_func = dataset2metric[dataset]
    total_score = 0.0
    scores = []

    # Follow LongBench's scorer function exactly
    for pred, ground_truths, all_classes in zip(predictions, answers_list,
                                                all_classes_list):
        score = 0.0

        # Apply the same prediction cleaning as LongBench
        if dataset in ["trec", "triviaqa", "samsum", "lsht"]:
            pred = pred.lstrip('\n').split('\n')[0]

        # For code datasets, apply additional cleaning
        if dataset in ["lcc", "repobench-p"]:
            # This cleaning is done inside code_sim_score, but let's also apply it here for consistency
            all_lines = pred.lstrip('\n').split('\n')
            for line in all_lines:
                if ('`' not in line) and ('#' not in line) and ('//' not in line):
                    pred = line
                    break

        # Calculate max score across all reference answers (exactly as in LongBench)
        for ground_truth in ground_truths:
            score = max(
                score, metric_func(pred, ground_truth, all_classes=all_classes))

        scores.append(score)
        total_score += score

    final_score = round(100 * total_score / len(predictions), 2)
    return {metric_func.__name__: final_score}, scores
def calculate_task_summary(all_results: Dict[str, Dict]) -> Dict[str, Any]:
    """Calculate task-level summary statistics following long_bench_tasks_summary.py approach."""
    logger.info("Calculating task-level summary statistics...")

    summary = {}
    ind_dataset_result = {}
    task_ave_result = {}

    NA_flag = False

    # Get individual dataset results
    for dataset in LONGBENCH_DATASETS:
        if dataset in all_results and 'metrics' in all_results[dataset]:
            metrics = all_results[dataset]['metrics']
            # Get the first (and usually only) metric value
            if metrics:
                metric_key = list(metrics.keys())[0]
                val = metrics[metric_key]
                ind_dataset_result[dataset] = val
            else:
                ind_dataset_result[dataset] = 'N/A'
                NA_flag = True
        else:
            ind_dataset_result[dataset] = 'N/A'
            NA_flag = True

    summary['individual_dataset_result'] = ind_dataset_result

    # Calculate task-average results
    for task, datasets in TASK_DATASETS.items():
        task_NA_flag = False
        task_ave_result[task] = 0
        valid_count = 0

        for dataset in datasets:
            if dataset in ind_dataset_result and ind_dataset_result[
                    dataset] != 'N/A':
                task_ave_result[task] += ind_dataset_result[dataset]
                valid_count += 1
            else:
                task_NA_flag = True

        if task_NA_flag or valid_count == 0:
            task_ave_result[task] = 'N/A'
        else:
            task_ave_result[task] = np.round(task_ave_result[task] /
                                             valid_count,
                                             decimals=2)

    summary['task_average_result'] = task_ave_result

    # Calculate overall LongBench average result (excluding passage_count as in original)
    if NA_flag:
        summary['LB_average_result'] = 'N/A'
    else:
        average_result = 0
        valid_count = 0
        for dataset in LONGBENCH_DATASETS:
            if dataset != 'passage_count' and dataset in ind_dataset_result:
                if ind_dataset_result[dataset] != 'N/A':
                    average_result += ind_dataset_result[dataset]
                    valid_count += 1

        if valid_count > 0:
            summary['LB_average_result'] = np.round(average_result /
                                                    valid_count,
                                                    decimals=2)
        else:
            summary['LB_average_result'] = 'N/A'

    # Log summary statistics
    logger.info("Task Summary Results:")
    logger.info(f"Overall LongBench Average: {summary['LB_average_result']}")
    for task, score in task_ave_result.items():
        logger.info(f"{task}: {score}")

    return summary


def save_results(results: List[Dict], dataset: str, args: argparse.Namespace,
                 inference_time: float, output_dir: str):
    """Save evaluation results in format compatible with LongBench."""
    task_output_dir = os.path.join(output_dir, dataset)
    os.makedirs(task_output_dir, exist_ok=True)

    # Extract predictions, answers, and all_classes for evaluation
    predictions = [r['prediction'] for r in results]
    answers_list = [r['answers'] for r in results]
    all_classes_list = [r.get('all_classes', []) for r in results]

    # Calculate metrics
    processed_results, scores = calculate_metrics(dataset, predictions,
                                                  answers_list,
                                                  all_classes_list,
                                                  args.longbench_path)
    logger.info(f"Evaluation metrics: {processed_results}")

    # Save detailed results for manual inspection
    results_file = os.path.join(task_output_dir, f"{dataset}_results.jsonl")
    with open(results_file, 'w', encoding='utf-8') as f:
        for result in results:
            json.dump(result, f, ensure_ascii=False)
            f.write('\n')

    # Save prediction results in LongBench format for evaluation
    pred_dir = os.path.join(task_output_dir, "pred")
|
||||
os.makedirs(pred_dir, exist_ok=True)
|
||||
pred_file = os.path.join(pred_dir, f"{dataset}.jsonl")
|
||||
|
||||
with open(pred_file, 'w', encoding='utf-8') as f:
|
||||
for idx, result in enumerate(results):
|
||||
pred_data = {
|
||||
"pred": result['prediction'],
|
||||
"answers": result['answers'],
|
||||
"all_classes": result.get('all_classes', []),
|
||||
"length": result.get('length', 0),
|
||||
"score": scores[idx]
|
||||
}
|
||||
json.dump(pred_data, f, ensure_ascii=False)
|
||||
f.write('\n')
|
||||
|
||||
# Create summary following LongBench format
|
||||
config = {
|
||||
'pipeline_params': {
|
||||
'model_name': args.model_path,
|
||||
'method': args.attention_backend,
|
||||
'token_budget': args.token_budget,
|
||||
'max_seq_len': args.max_seq_len,
|
||||
'max_new_tokens': args.max_new_tokens,
|
||||
'window_size': args.window_size,
|
||||
'kernel_size': args.kernel_size,
|
||||
'num_processes': 1, # Single process
|
||||
'devices': "0" # Single device
|
||||
},
|
||||
'eval_params': {
|
||||
'dataset': dataset,
|
||||
'num_samples': len(results)
|
||||
},
|
||||
'eval_results': {
|
||||
'processed_results': processed_results
|
||||
},
|
||||
'management': {
|
||||
'output_folder_dir': task_output_dir,
|
||||
'exp_desc':
|
||||
f'{dataset}_{os.path.basename(args.model_path)}_{args.attention_backend}_{args.token_budget}',
|
||||
'total_inference_time': inference_time,
|
||||
'avg_inference_time': inference_time / len(results),
|
||||
'evaluation_timestamp': datetime.now().isoformat()
|
||||
}
|
||||
}
|
||||
|
||||
# Save summary
|
||||
summary_file = os.path.join(task_output_dir, f"{dataset}_summary.json")
|
||||
with open(summary_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(config, f, indent=2, ensure_ascii=False)
|
||||
|
||||
logger.info(f"The results of {dataset} are saved to {task_output_dir}")
|
||||
|
||||
return processed_results
|
||||
|
||||
|
||||
def main():
|
||||
"""Main evaluation function."""
|
||||
args = parse_arguments()
|
||||
logger.set_level(args.log_level)
|
||||
|
||||
# Setup experiment name
|
||||
if not args.exp_name:
|
||||
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
||||
model_name = os.path.basename(args.model_path).replace('/', '_')
|
||||
args.exp_name = f"longbench_{model_name}_{timestamp}"
|
||||
|
||||
output_dir = os.path.join(args.output_dir, args.exp_name)
|
||||
|
||||
logger.info(
|
||||
"=========== LongBench Evaluation with TensorRT-LLM ===========")
|
||||
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
# Save configuration
|
||||
config_file = os.path.join(output_dir, "config.json")
|
||||
with open(config_file, 'w') as f:
|
||||
json.dump(vars(args), f, indent=2)
|
||||
logger.info(f"Configuration saved to {config_file}")
|
||||
|
||||
# Determine datasets to evaluate
|
||||
if args.run_all_tasks:
|
||||
datasets = LONGBENCH_DATASETS
|
||||
logger.info(f"Running evaluation on full LongBench datasets")
|
||||
else:
|
||||
datasets = args.dataset
|
||||
logger.info(f"Running evaluation on datasets: {args.dataset}")
|
||||
|
||||
# Initialize LLM and tokenizer
|
||||
llm, tokenizer = initialize_llm(args)
|
||||
|
||||
# Process datasets sequentially
|
||||
all_results = {}
|
||||
for dataset_idx, dataset in enumerate(datasets):
|
||||
logger.info(f"{'='*30}")
|
||||
logger.info(
|
||||
f"Processing dataset {dataset_idx+1}/{len(datasets)}: {dataset}...")
|
||||
|
||||
# Evaluate the dataset
|
||||
results, inference_time = evaluate_single_dataset(
|
||||
dataset, llm, tokenizer, args)
|
||||
|
||||
# Save results and get metrics
|
||||
processed_results = save_results(results, dataset, args, inference_time,
|
||||
output_dir)
|
||||
|
||||
all_results[dataset] = {
|
||||
'num_samples': len(results),
|
||||
'inference_time': inference_time,
|
||||
'output_dir': output_dir,
|
||||
'metrics': processed_results
|
||||
}
|
||||
logger.info(f"Dataset {dataset} completed successfully")
|
||||
|
||||
total_time = sum(all_results[d]['inference_time'] for d in all_results)
|
||||
|
||||
# Calculate task-level summary
|
||||
task_summary = calculate_task_summary(all_results)
|
||||
|
||||
# Save overall summary with task statistics
|
||||
overall_summary = {
|
||||
'experiment_name':
|
||||
args.exp_name,
|
||||
'total_evaluation_time':
|
||||
total_time,
|
||||
'evaluated_datasets':
|
||||
list(all_results.keys()),
|
||||
'successful_datasets':
|
||||
[d for d, r in all_results.items() if 'error' not in r],
|
||||
'failed_datasets': [d for d, r in all_results.items() if 'error' in r],
|
||||
'results_by_dataset':
|
||||
all_results,
|
||||
'task_summary':
|
||||
task_summary, # Add task-level summary
|
||||
'configuration':
|
||||
vars(args)
|
||||
}
|
||||
|
||||
overall_summary_file = os.path.join(output_dir, "overall_summary.json")
|
||||
with open(overall_summary_file, 'w') as f:
|
||||
json.dump(overall_summary, f, indent=2)
|
||||
|
||||
logger.info(f"\n{'-'*80}")
|
||||
logger.info(
|
||||
f"Evaluation completed. Overall summary saved to: {overall_summary_file}"
|
||||
)
|
||||
logger.info(f"Total time: {total_time:.2f} seconds")
|
||||
|
||||
# Print final summary
|
||||
if task_summary['LB_average_result'] != 'N/A':
|
||||
logger.info(f"FINAL RESULTS:")
|
||||
logger.info(
|
||||
f"LongBench Overall Average: {task_summary['LB_average_result']}")
|
||||
logger.info(f"Task-level results:")
|
||||
for task, score in task_summary['task_average_result'].items():
|
||||
logger.info(f" {task}: {score}")
|
||||
|
||||
if overall_summary['failed_datasets']:
|
||||
logger.warning(f"Failed datasets: {overall_summary['failed_datasets']}")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
771
examples/longbench/eval_longbench_v2.py
Normal file
@@ -0,0 +1,771 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
LongBench v2 evaluation script with TensorRT-LLM and sparse attention.
|
||||
|
||||
Usage:
|
||||
python eval_longbench_v2.py --model_path /path/to/model --longbench_path ./LongBench --output_dir results/
|
||||
|
||||
# Run all LongBench v2 samples
|
||||
python eval_longbench_v2.py --model_path /path/to/model --longbench_path ./LongBench --output_dir results/
|
||||
|
||||
# Enable CoT reasoning
|
||||
python eval_longbench_v2.py --model_path /path/to/model --longbench_path ./LongBench --output_dir results/ --cot
|
||||
|
||||
# Run with different difficulty levels
|
||||
python eval_longbench_v2.py --model_path /path/to/model --longbench_path ./LongBench --output_dir results/ --difficulty easy
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
# Add tensorrt_llm imports
|
||||
from tensorrt_llm import LLM, SamplingParams
|
||||
from tensorrt_llm.llmapi import KvCacheConfig, RocketSparseAttentionConfig
|
||||
from tensorrt_llm.logger import logger
|
||||
|
||||
# Chat templates mapping
|
||||
CHAT_TEMPLATES = {
|
||||
"llama3.1-8b-instruct": "llama3",
|
||||
"llama3-8b-instruct": "llama3",
|
||||
"mistral-7b-instruct-v0.2": "mistral",
|
||||
"longchat-7b-v1.5-32k": "vicuna"
|
||||
}
|
||||
|
||||
|
||||
def parse_arguments() -> argparse.Namespace:
|
||||
"""Parse command line arguments."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="LongBench v2 evaluation with TensorRT-LLM and RocketKV")
|
||||
|
||||
# Model and data arguments
|
||||
parser.add_argument('--model_path',
|
||||
type=str,
|
||||
required=True,
|
||||
help='Path to model (HF model name or local path)')
|
||||
parser.add_argument('--longbench_path',
|
||||
type=str,
|
||||
default='./LongBench',
|
||||
help='Path to LongBench directory')
|
||||
|
||||
# Output arguments
|
||||
parser.add_argument('--output_dir',
|
||||
type=str,
|
||||
required=True,
|
||||
help='Directory to save results')
|
||||
parser.add_argument('--exp_name',
|
||||
type=str,
|
||||
default=None,
|
||||
help='Experiment name (auto-generated if not provided)')
|
||||
|
||||
# Model configuration
|
||||
parser.add_argument('--attention_backend',
|
||||
type=str,
|
||||
default='VANILLA',
|
||||
choices=['VANILLA', 'TRTLLM', 'FLASHINFER'],
|
||||
help='Attention backend to use')
|
||||
parser.add_argument('--backend',
|
||||
type=str,
|
||||
default='pytorch',
|
||||
choices=['pytorch', 'tensorrt'],
|
||||
help='LLM backend to use')
|
||||
parser.add_argument('--chat_template',
|
||||
type=str,
|
||||
default='auto',
|
||||
help='Chat template to use (auto-detect if "auto")')
|
||||
|
||||
# Sequence and batch configuration
|
||||
parser.add_argument('--max_seq_len',
|
||||
type=int,
|
||||
default=133120,
|
||||
help='Maximum sequence length')
|
||||
parser.add_argument('--max_batch_size',
|
||||
type=int,
|
||||
default=1,
|
||||
help='Maximum batch size')
|
||||
parser.add_argument('--max_new_tokens',
|
||||
type=int,
|
||||
default=256,
|
||||
help='Maximum new tokens to generate')
|
||||
parser.add_argument(
|
||||
'--max_num_tokens',
|
||||
type=int,
|
||||
default=133120,
|
||||
help='Maximum total tokens across all sequences in a batch')
|
||||
parser.add_argument('--tensor_parallel_size',
|
||||
type=int,
|
||||
default=1,
|
||||
help='Tensor parallel size')
|
||||
|
||||
# RocketKV configuration
|
||||
parser.add_argument('--rocket_sparse',
|
||||
action='store_true',
|
||||
help='Use rocket sparse attention')
|
||||
parser.add_argument('--token_budget',
|
||||
type=int,
|
||||
default=2048,
|
||||
help='Token budget for RocketKV (prompt_budget)')
|
||||
parser.add_argument('--window_size',
|
||||
type=int,
|
||||
default=32,
|
||||
help='Window size for RocketKV')
|
||||
parser.add_argument('--kernel_size',
|
||||
type=int,
|
||||
default=63,
|
||||
help='Kernel size for RocketKV')
|
||||
parser.add_argument('--topr',
|
||||
type=int,
|
||||
default=90,
|
||||
help='Top-r for RocketKV')
|
||||
|
||||
# KV cache configuration
|
||||
parser.add_argument('--kv_cache_dtype',
|
||||
type=str,
|
||||
default='auto',
|
||||
help='KV cache data type')
|
||||
parser.add_argument('--kv_cache_fraction',
|
||||
type=float,
|
||||
default=0.7,
|
||||
help='Fraction of GPU memory for KV cache')
|
||||
|
||||
# LongBench v2 specific arguments
|
||||
parser.add_argument('--cot',
|
||||
action='store_true',
|
||||
help='Enable Chain-of-Thought reasoning')
|
||||
parser.add_argument('--no_context',
|
||||
action='store_true',
|
||||
help='Test without long context (pure memorization)')
|
||||
parser.add_argument('--rag',
|
||||
type=int,
|
||||
default=0,
|
||||
help='Use top-N retrieved contexts (0 to disable)')
|
||||
|
||||
# Evaluation parameters
|
||||
parser.add_argument('--num_samples',
|
||||
type=int,
|
||||
default=None,
|
||||
help='Number of samples to evaluate (None for all)')
|
||||
parser.add_argument('--start_idx',
|
||||
type=int,
|
||||
default=0,
|
||||
help='Start index for evaluation')
|
||||
parser.add_argument('--difficulty',
|
||||
type=str,
|
||||
choices=['easy', 'hard'],
|
||||
default=None,
|
||||
help='Filter by difficulty level')
|
||||
parser.add_argument('--length',
|
||||
type=str,
|
||||
choices=['short', 'medium', 'long'],
|
||||
default=None,
|
||||
help='Filter by length category')
|
||||
parser.add_argument('--domain',
|
||||
type=str,
|
||||
default=None,
|
||||
help='Filter by domain')
|
||||
parser.add_argument('--max_len',
|
||||
type=int,
|
||||
default=120000,
|
||||
help='Maximum prompt length in tokens for truncation')
|
||||
|
||||
# System arguments
|
||||
parser.add_argument('--log_level',
|
||||
type=str,
|
||||
default='info',
|
||||
choices=['debug', 'info', 'warning', 'error'],
|
||||
help='Logging level')
|
||||
parser.add_argument('--seed', type=int, default=42, help='Random seed')
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def load_longbench_v2_config(longbench_path: str) -> Dict[str, Any]:
|
||||
"""Load LongBench v2 configuration files."""
|
||||
config_dir = os.path.join(longbench_path, "config")
|
||||
|
||||
# Load model2maxlen.json for v2
|
||||
maxlen_file = os.path.join(config_dir, "model2maxlen.json")
|
||||
with open(maxlen_file, 'r', encoding='utf-8') as f:
|
||||
model2maxlen = json.load(f)
|
||||
|
||||
# Load prompt templates
|
||||
prompts_dir = os.path.join(longbench_path, "prompts")
|
||||
|
||||
templates = {}
|
||||
template_files = {
|
||||
'0shot': '0shot.txt',
|
||||
'0shot_cot': '0shot_cot.txt',
|
||||
'0shot_cot_ans': '0shot_cot_ans.txt',
|
||||
'0shot_no_context': '0shot_no_context.txt',
|
||||
'0shot_rag': '0shot_rag.txt'
|
||||
}
|
||||
|
||||
for template_name, filename in template_files.items():
|
||||
template_path = os.path.join(prompts_dir, filename)
|
||||
if os.path.exists(template_path):
|
||||
with open(template_path, 'r', encoding='utf-8') as f:
|
||||
templates[template_name] = f.read()
|
||||
|
||||
return {'model2maxlen': model2maxlen, 'templates': templates}
|
||||
|
||||
|
||||
def build_chat(tokenizer, prompt, chat_template):
|
||||
"""Build chat prompt following LongBench's approach."""
|
||||
if chat_template == "vicuna" or chat_template == "longchat":
|
||||
try:
|
||||
from fastchat.model import get_conversation_template
|
||||
conv = get_conversation_template("vicuna")
|
||||
conv.append_message(conv.roles[0], prompt)
|
||||
conv.append_message(conv.roles[1], None)
|
||||
prompt = conv.get_prompt()
|
||||
except ImportError:
|
||||
# Fallback if fastchat is not available
|
||||
prompt = f"A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n\nUSER: {prompt}\nASSISTANT:"
|
||||
elif chat_template == "llama2":
|
||||
prompt = f"[INST]{prompt}[/INST]"
|
||||
elif chat_template == "llama3":
|
||||
prompt = f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
elif chat_template == "mistral":
|
||||
prompt = f"[INST] {prompt} [/INST]"
|
||||
# For other templates or "none", return prompt as-is
|
||||
return prompt
|
||||
|
||||
|
||||
def determine_chat_template(model_path: str, chat_template: str) -> str:
|
||||
"""Determine chat template based on model path."""
|
||||
if chat_template != 'auto':
|
||||
return chat_template
|
||||
|
||||
model_path_lower = model_path.lower()
|
||||
|
||||
for model_key, template in CHAT_TEMPLATES.items():
|
||||
        key_norm = model_key.replace('-', '').replace('.', '')
        path_norm = model_path_lower.replace('-', '').replace('.', '')
        if key_norm in path_norm:
            return template
|
||||
|
||||
# Default fallback
|
||||
if 'llama' in model_path_lower:
|
||||
return 'llama3'
|
||||
elif 'mistral' in model_path_lower:
|
||||
return 'mistral'
|
||||
else:
|
||||
return 'none' # No special formatting
|
||||
|
||||
|
||||
def extract_answer(response: str) -> Optional[str]:
|
||||
"""Extract answer from response following LongBench v2's approach."""
|
||||
response = response.replace('*', '')
|
||||
|
||||
# Try to extract answer in format "The correct answer is (X)"
|
||||
match = re.search(r'The correct answer is \(([A-D])\)', response)
|
||||
if match:
|
||||
return match.group(1)
|
||||
|
||||
# Try to extract answer in format "The correct answer is X"
|
||||
match = re.search(r'The correct answer is ([A-D])', response)
|
||||
if match:
|
||||
return match.group(1)
|
||||
|
||||
# Try to extract any single letter A-D
|
||||
match = re.search(r'\b([A-D])\b', response)
|
||||
if match:
|
||||
return match.group(1)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
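extract_answer prefers the explicit "The correct answer is (X)" / "The correct answer is X" patterns and only then falls back to any standalone letter A-D. A quick sanity sketch (assuming extract_answer from this script is in scope):

assert extract_answer("The correct answer is (B).") == "B"
assert extract_answer("I believe B is the best option") == "B"  # bare-letter fallback
assert extract_answer("Not enough information to decide") is None
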
def post_process(pred: str, chat_template: str) -> str:
|
||||
"""Post-process prediction following LongBench v2's approach."""
|
||||
pred = pred.split("</s")[0].strip()
|
||||
if chat_template == "qwen":
|
||||
pred = pred.split("<|im_end|>")[0]
|
||||
elif "llama2" in chat_template.lower():
|
||||
pred = (pred.split("(Document")[0].split("\n\nQuestion")[0].split(
|
||||
"\n\nAnswer")[0].split("[INST]")[0].split("[/INST]")[0].split(
|
||||
"(Passage")[0].strip())
|
||||
|
||||
return pred
|
||||
|
||||
|
||||
def truncate_prompt(prompt: str, tokenizer: AutoTokenizer, max_len: int) -> str:
|
||||
"""Truncate prompt following LongBench v2's approach."""
|
||||
# Encode the prompt using the tokenizer
|
||||
input_ids = tokenizer.encode(prompt, add_special_tokens=False)
|
||||
|
||||
# If prompt exceeds max_len, truncate by taking first half and last half
|
||||
if len(input_ids) > max_len:
|
||||
half = max_len // 2
|
||||
truncated_ids = input_ids[:half] + input_ids[-half:]
|
||||
# Decode back to text
|
||||
prompt = tokenizer.decode(truncated_ids, skip_special_tokens=True)
|
||||
|
||||
return prompt
|
||||
|
||||
|
||||
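truncate_prompt implements LongBench v2's middle truncation: when the tokenized prompt exceeds max_len, only the first and last max_len // 2 tokens are kept. A minimal self-contained sketch of that behavior with hypothetical token ids:

ids = list(range(10))  # stand-in for tokenizer.encode(prompt, add_special_tokens=False)
max_len = 4
half = max_len // 2
assert ids[:half] + ids[-half:] == [0, 1, 8, 9]  # middle tokens 2..7 are dropped
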
def format_prompt(sample: Dict[str, Any], template: str,
|
||||
args: argparse.Namespace) -> str:
|
||||
"""Format prompt for LongBench v2."""
|
||||
context = sample['context']
|
||||
|
||||
# Handle RAG mode
|
||||
if args.rag > 0 and 'retrieved_context' in sample:
|
||||
retrieved = sample["retrieved_context"][:args.rag]
|
||||
retrieved = sorted(retrieved, key=lambda x: x.get('c_idx', 0))
|
||||
context = '\n\n'.join([
|
||||
f"Retrieved chunk {idx+1}: {x['content']}"
|
||||
for idx, x in enumerate(retrieved)
|
||||
])
|
||||
|
||||
# Handle no context mode
|
||||
if args.no_context:
|
||||
context = ""
|
||||
|
||||
# Format the prompt using the template
|
||||
prompt = template.replace('$DOC$', context.strip())
|
||||
prompt = prompt.replace('$Q$', sample['question'].strip())
|
||||
prompt = prompt.replace('$C_A$', sample['choice_A'].strip())
|
||||
prompt = prompt.replace('$C_B$', sample['choice_B'].strip())
|
||||
prompt = prompt.replace('$C_C$', sample['choice_C'].strip())
|
||||
prompt = prompt.replace('$C_D$', sample['choice_D'].strip())
|
||||
|
||||
return prompt
|
||||
|
||||
|
||||
def initialize_llm(args: argparse.Namespace) -> Tuple[LLM, AutoTokenizer]:
|
||||
"""Initialize LLM and tokenizer."""
|
||||
logger.info(f"Initializing LLM with model: {args.model_path}")
|
||||
|
||||
try:
|
||||
# Configure KV cache
|
||||
kv_cache_config = KvCacheConfig(
|
||||
enable_block_reuse=False, # RocketKV doesn't support KV cache reuse
|
||||
)
|
||||
|
||||
if args.rocket_sparse:
|
||||
# Configure RocketKV sparse attention
|
||||
sparse_attention_config = RocketSparseAttentionConfig(
|
||||
window_size=args.window_size,
|
||||
kernel_size=args.kernel_size,
|
||||
prompt_budget=args.token_budget,
|
||||
topr=args.topr,
|
||||
)
|
||||
logger.info(f"Using RocketKV sparse attention")
|
||||
else:
|
||||
sparse_attention_config = None
|
||||
logger.info("Using standard attention")
|
||||
|
||||
# Initialize LLM
|
||||
llm = LLM(
|
||||
model=args.model_path,
|
||||
backend=args.backend,
|
||||
kv_cache_config=kv_cache_config,
|
||||
attn_backend=args.attention_backend,
|
||||
sparse_attention_config=sparse_attention_config,
|
||||
tensor_parallel_size=args.tensor_parallel_size,
|
||||
max_seq_len=args.max_seq_len,
|
||||
max_num_tokens=args.max_num_tokens,
|
||||
cuda_graph_config=None,
|
||||
torch_compile_config=None,
|
||||
)
|
||||
|
||||
# Initialize tokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model_path)
|
||||
|
||||
logger.info("LLM and tokenizer initialized successfully")
|
||||
|
||||
return llm, tokenizer
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize LLM: {e}")
|
||||
raise e
|
||||
|
||||
|
||||
def evaluate_longbench_v2(llm: LLM, tokenizer: AutoTokenizer,
|
||||
args: argparse.Namespace) -> Tuple[List[Dict], float]:
|
||||
"""Evaluate on LongBench v2 dataset."""
|
||||
|
||||
# Load LongBench v2 configuration
|
||||
config = load_longbench_v2_config(args.longbench_path)
|
||||
|
||||
# Determine max_len for the model if not explicitly set via arguments
|
||||
model_name = os.path.basename(args.model_path)
|
||||
if model_name in config[
|
||||
'model2maxlen']: # Use default from config if available
|
||||
max_len = config['model2maxlen'][model_name]
|
||||
logger.info(f"Using model-specific max_len: {max_len} for {model_name}")
|
||||
else:
|
||||
max_len = args.max_len
|
||||
logger.info(f"Using max_len: {max_len}")
|
||||
|
||||
# Update args with the determined max_len
|
||||
args.max_len = max_len
|
||||
|
||||
# Load dataset
|
||||
logger.info(f"Loading LongBench v2 dataset...")
|
||||
dataset = load_dataset('THUDM/LongBench-v2',
|
||||
split='train',
|
||||
trust_remote_code=True)
|
||||
data = [item for item in dataset]
|
||||
|
||||
# Apply filters
|
||||
filtered_data = data
|
||||
|
||||
if args.difficulty:
|
||||
filtered_data = [
|
||||
item for item in filtered_data
|
||||
if item['difficulty'] == args.difficulty
|
||||
]
|
||||
logger.info(
|
||||
f"Filtered by difficulty '{args.difficulty}': {len(filtered_data)} samples"
|
||||
)
|
||||
|
||||
if args.length:
|
||||
filtered_data = [
|
||||
item for item in filtered_data if item['length'] == args.length
|
||||
]
|
||||
logger.info(
|
||||
f"Filtered by length '{args.length}': {len(filtered_data)} samples")
|
||||
|
||||
if args.domain:
|
||||
filtered_data = [
|
||||
item for item in filtered_data if item['domain'] == args.domain
|
||||
]
|
||||
logger.info(
|
||||
f"Filtered by domain '{args.domain}': {len(filtered_data)} samples")
|
||||
|
||||
# Apply start_idx and num_samples
|
||||
if args.num_samples:
|
||||
end_idx = min(args.start_idx + args.num_samples, len(filtered_data))
|
||||
filtered_data = filtered_data[args.start_idx:end_idx]
|
||||
else:
|
||||
filtered_data = filtered_data[args.start_idx:]
|
||||
|
||||
logger.info(f"Final dataset size: {len(filtered_data)} samples")
|
||||
|
||||
# Determine chat template
|
||||
chat_template = determine_chat_template(args.model_path, args.chat_template)
|
||||
logger.info(f"Using chat template: {chat_template}")
|
||||
|
||||
logger.info(f"Prepare and truncate prompts...")
|
||||
# Select appropriate template
|
||||
if args.no_context:
|
||||
template_key = '0shot_no_context'
|
||||
elif args.rag > 0:
|
||||
template_key = '0shot_rag'
|
||||
elif args.cot:
|
||||
template_key = '0shot_cot'
|
||||
else:
|
||||
template_key = '0shot'
|
||||
|
||||
template = config['templates'][template_key]
|
||||
logger.info(f"Using template: {template_key}")
|
||||
|
||||
# Set up extra end token ids
|
||||
extra_end_token_ids = []
|
||||
if chat_template == "llama3":
|
||||
eot_id = tokenizer.encode("<|eot_id|>", add_special_tokens=False)[0]
|
||||
extra_end_token_ids.append(eot_id)
|
||||
logger.info(f"Added llama3 end token: {eot_id}")
|
||||
|
||||
if chat_template == "qwen":
|
||||
im_end_id = tokenizer.encode("<|im_end|>", add_special_tokens=False)[0]
|
||||
extra_end_token_ids.append(im_end_id)
|
||||
logger.info(f"Added qwen end token: {im_end_id}")
|
||||
|
||||
# Prepare prompts
|
||||
prompts = []
|
||||
for sample in filtered_data:
|
||||
formatted_prompt = format_prompt(sample, template, args)
|
||||
|
||||
# Apply chat formatting if needed
|
||||
if chat_template != 'none':
|
||||
formatted_prompt = build_chat(tokenizer, formatted_prompt,
|
||||
chat_template)
|
||||
|
||||
# Apply truncation if prompt is too long
|
||||
formatted_prompt = truncate_prompt(formatted_prompt, tokenizer,
|
||||
args.max_len)
|
||||
|
||||
prompts.append(formatted_prompt)
|
||||
|
||||
if len(prompts) == 0:
|
||||
logger.warning(f"No prompts to evaluate")
|
||||
return [], 0.0
|
||||
|
||||
# Run inference
|
||||
logger.info(f"Starting inference for {len(prompts)} samples...")
|
||||
start_time = time.time()
|
||||
|
||||
# Set sampling parameters
|
||||
    max_new_tokens = 1024 if args.cot else args.max_new_tokens
|
||||
sampling_params = SamplingParams(
|
||||
max_tokens=max_new_tokens,
|
||||
temperature=0.1,
|
||||
top_p=0.95,
|
||||
stop_token_ids=extra_end_token_ids if extra_end_token_ids else None,
|
||||
)
|
||||
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
|
||||
inference_time = time.time() - start_time
|
||||
logger.info(
|
||||
f"Inference completed in {inference_time:.2f} seconds, average time per sample: {inference_time/len(prompts):.3f} seconds"
|
||||
)
|
||||
|
||||
# Process outputs
|
||||
results = []
|
||||
for i, (sample, output) in enumerate(zip(filtered_data, outputs)):
|
||||
prediction = output.outputs[0].text.strip()
|
||||
processed_prediction = post_process(prediction, chat_template)
|
||||
|
||||
# Handle CoT mode
|
||||
if args.cot:
|
||||
# For CoT, we need to do a second inference to extract the final answer
|
||||
cot_response = processed_prediction
|
||||
|
||||
# Format the CoT answer extraction prompt
|
||||
cot_ans_template = config['templates']['0shot_cot_ans']
|
||||
cot_ans_prompt = format_prompt(sample, cot_ans_template, args)
|
||||
cot_ans_prompt = cot_ans_prompt.replace('$COT$', cot_response)
|
||||
|
||||
if chat_template != 'none':
|
||||
cot_ans_prompt = build_chat(tokenizer, cot_ans_prompt,
|
||||
chat_template)
|
||||
|
||||
# Apply truncation to CoT answer extraction prompt
|
||||
cot_ans_prompt = truncate_prompt(cot_ans_prompt, tokenizer,
|
||||
args.max_len)
|
||||
|
||||
# Generate final answer
|
||||
ans_sampling_params = SamplingParams(
|
||||
max_tokens=128,
|
||||
temperature=0.1,
|
||||
top_p=0.95,
|
||||
stop_token_ids=extra_end_token_ids
|
||||
if extra_end_token_ids else None,
|
||||
)
|
||||
|
||||
ans_output = llm.generate([cot_ans_prompt], ans_sampling_params)[0]
|
||||
final_prediction = post_process(ans_output.outputs[0].text.strip(),
|
||||
chat_template)
|
||||
|
||||
extracted_answer = extract_answer(final_prediction)
|
||||
else:
|
||||
extracted_answer = extract_answer(processed_prediction)
|
||||
|
||||
# Calculate accuracy
|
||||
is_correct = extracted_answer == sample[
|
||||
'answer'] if extracted_answer else False
|
||||
|
||||
result = {
|
||||
'_id': sample['_id'],
|
||||
'domain': sample['domain'],
|
||||
'sub_domain': sample['sub_domain'],
|
||||
'difficulty': sample['difficulty'],
|
||||
'length': sample['length'],
|
||||
'question': sample['question'],
|
||||
'choice_A': sample['choice_A'],
|
||||
'choice_B': sample['choice_B'],
|
||||
'choice_C': sample['choice_C'],
|
||||
'choice_D': sample['choice_D'],
|
||||
'answer': sample['answer'],
|
||||
'prediction': processed_prediction,
|
||||
'extracted_answer': extracted_answer,
|
||||
'is_correct': is_correct,
|
||||
'context_length': len(sample['context']),
|
||||
'prompt_length': len(output.prompt_token_ids),
|
||||
'output_length': len(output.outputs[0].token_ids),
|
||||
}
|
||||
|
||||
if args.cot:
|
||||
result['cot_response'] = cot_response
|
||||
result['final_prediction'] = final_prediction
|
||||
|
||||
results.append(result)
|
||||
|
||||
return results, inference_time
|
||||
|
||||
|
||||
def calculate_metrics(results: List[Dict]) -> Dict[str, Any]:
|
||||
"""Calculate evaluation metrics for LongBench v2."""
|
||||
if not results:
|
||||
return {}
|
||||
|
||||
total_samples = len(results)
|
||||
correct_samples = sum(1 for r in results if r['is_correct'])
|
||||
overall_accuracy = correct_samples / total_samples
|
||||
|
||||
metrics = {
|
||||
'overall_accuracy': round(overall_accuracy * 100, 2),
|
||||
'total_samples': total_samples,
|
||||
'correct_samples': correct_samples
|
||||
}
|
||||
|
||||
# Calculate metrics by difficulty
|
||||
difficulties = ['easy', 'hard']
|
||||
for difficulty in difficulties:
|
||||
diff_results = [r for r in results if r['difficulty'] == difficulty]
|
||||
if diff_results:
|
||||
diff_correct = sum(1 for r in diff_results if r['is_correct'])
|
||||
            metrics[f'{difficulty}_accuracy'] = round(
                (diff_correct / len(diff_results)) * 100, 2)
            metrics[f'{difficulty}_samples'] = len(diff_results)
|
||||
|
||||
# Calculate metrics by length
|
||||
lengths = ['short', 'medium', 'long']
|
||||
for length in lengths:
|
||||
len_results = [r for r in results if r['length'] == length]
|
||||
if len_results:
|
||||
len_correct = sum(1 for r in len_results if r['is_correct'])
|
||||
            metrics[f'{length}_accuracy'] = round(
                (len_correct / len(len_results)) * 100, 2)
            metrics[f'{length}_samples'] = len(len_results)
|
||||
|
||||
# Calculate metrics by domain
|
||||
domains = list(set(r['domain'] for r in results))
|
||||
for domain in domains:
|
||||
domain_results = [r for r in results if r['domain'] == domain]
|
||||
if domain_results:
|
||||
domain_correct = sum(1 for r in domain_results if r['is_correct'])
|
||||
metrics[f'{domain}_accuracy'] = round(
|
||||
(domain_correct / len(domain_results)) * 100, 2)
|
||||
|
||||
return metrics
|
||||
|
||||
|
||||
def save_results(results: List[Dict], args: argparse.Namespace,
|
||||
inference_time: float, output_dir: str):
|
||||
"""Save evaluation results in format compatible with LongBench v2."""
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# Calculate metrics
|
||||
metrics = calculate_metrics(results)
|
||||
logger.info(f"Evaluation metrics: {metrics}")
|
||||
|
||||
# Save detailed results
|
||||
results_file = os.path.join(output_dir, "longbench_v2_results.jsonl")
|
||||
with open(results_file, 'w', encoding='utf-8') as f:
|
||||
for result in results:
|
||||
json.dump(result, f, ensure_ascii=False)
|
||||
f.write('\n')
|
||||
|
||||
# Save prediction results in LongBench v2 format
|
||||
pred_file = os.path.join(output_dir, "predictions.jsonl")
|
||||
with open(pred_file, 'w', encoding='utf-8') as f:
|
||||
for result in results:
|
||||
pred_data = {
|
||||
"_id": result['_id'],
|
||||
"prediction": result['extracted_answer'],
|
||||
"response": result['prediction'],
|
||||
"judge": result['is_correct']
|
||||
}
|
||||
if args.cot:
|
||||
pred_data['cot_response'] = result.get('cot_response', '')
|
||||
pred_data['final_prediction'] = result.get(
|
||||
'final_prediction', '')
|
||||
|
||||
json.dump(pred_data, f, ensure_ascii=False)
|
||||
f.write('\n')
|
||||
|
||||
# Create summary
|
||||
summary = {
|
||||
'experiment_config': {
|
||||
'model_path': args.model_path,
|
||||
'attention_backend': args.attention_backend,
|
||||
'rocket_sparse': args.rocket_sparse,
|
||||
'token_budget': args.token_budget,
|
||||
'cot': args.cot,
|
||||
'no_context': args.no_context,
|
||||
'rag': args.rag,
|
||||
'difficulty_filter': args.difficulty,
|
||||
'length_filter': args.length,
|
||||
'domain_filter': args.domain,
|
||||
'max_seq_len': args.max_seq_len,
|
||||
'max_new_tokens': args.max_new_tokens
|
||||
},
|
||||
'evaluation_results': metrics,
|
||||
'timing': {
|
||||
'total_inference_time': inference_time,
|
||||
'avg_inference_time':
|
||||
inference_time / len(results) if results else 0,
|
||||
'evaluation_timestamp': datetime.now().isoformat()
|
||||
}
|
||||
}
|
||||
|
||||
# Save summary
|
||||
summary_file = os.path.join(output_dir, "summary.json")
|
||||
with open(summary_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(summary, f, indent=2, ensure_ascii=False)
|
||||
|
||||
logger.info(f"Results saved to {output_dir}")
|
||||
|
||||
return metrics
|
||||
|
||||
|
||||
def main():
|
||||
"""Main evaluation function."""
|
||||
args = parse_arguments()
|
||||
logger.set_level(args.log_level)
|
||||
|
||||
# Setup experiment name
|
||||
if not args.exp_name:
|
||||
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
||||
model_name = os.path.basename(args.model_path).replace('/', '_')
|
||||
args.exp_name = f"longbench_v2_{model_name}_{timestamp}"
|
||||
|
||||
output_dir = os.path.join(args.output_dir, args.exp_name)
|
||||
|
||||
logger.info(
|
||||
"=========== LongBench v2 Evaluation with TensorRT-LLM ===========")
|
||||
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# Save configuration
|
||||
config_file = os.path.join(output_dir, "config.json")
|
||||
with open(config_file, 'w') as f:
|
||||
json.dump(vars(args), f, indent=2)
|
||||
logger.info(f"Configuration saved to {config_file}")
|
||||
|
||||
# Initialize LLM and tokenizer
|
||||
llm, tokenizer = initialize_llm(args)
|
||||
|
||||
# Run evaluation
|
||||
logger.info(f"Starting LongBench v2 evaluation...")
|
||||
results, inference_time = evaluate_longbench_v2(llm, tokenizer, args)
|
||||
|
||||
# Save results and get metrics
|
||||
metrics = save_results(results, args, inference_time, output_dir)
|
||||
|
||||
logger.info(f"{'-'*80}")
|
||||
logger.info(f"Evaluation completed successfully!")
|
||||
logger.info(f"Total time: {inference_time:.2f} seconds")
|
||||
logger.info(f"Total samples: {len(results)}")
|
||||
|
||||
if metrics:
|
||||
logger.info(
|
||||
f"Overall accuracy: {metrics.get('overall_accuracy', 'N/A')}%")
|
||||
|
||||
if 'easy_accuracy' in metrics:
|
||||
logger.info(
|
||||
f"Easy difficulty accuracy: {metrics['easy_accuracy']}% ({metrics.get('easy_samples', 0)} samples)"
|
||||
)
|
||||
if 'hard_accuracy' in metrics:
|
||||
logger.info(
|
||||
f"Hard difficulty accuracy: {metrics['hard_accuracy']}% ({metrics.get('hard_samples', 0)} samples)"
|
||||
)
|
||||
|
||||
for length in ['short', 'medium', 'long']:
|
||||
if f'{length}_accuracy' in metrics:
|
||||
logger.info(
|
||||
f"{length.capitalize()} length accuracy: {metrics[f'{length}_accuracy']}% ({metrics.get(f'{length}_samples', 0)} samples)"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
3
examples/longbench/requirements.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
jieba
|
||||
fuzzywuzzy
|
||||
rouge
|
||||
@@ -1,5 +1,6 @@
|
||||
from ..flashinfer_utils import IS_FLASHINFER_AVAILABLE
|
||||
from .interface import AttentionBackend, AttentionMetadata
|
||||
from .sparse import get_sparse_attn_kv_cache_manager
|
||||
from .trtllm import AttentionInputType, TrtllmAttention, TrtllmAttentionMetadata
|
||||
from .vanilla import VanillaAttention, VanillaAttentionMetadata
|
||||
|
||||
@@ -11,6 +12,7 @@ __all__ = [
|
||||
"TrtllmAttentionMetadata",
|
||||
"VanillaAttention",
|
||||
"VanillaAttentionMetadata",
|
||||
"get_sparse_attn_kv_cache_manager",
|
||||
]
|
||||
|
||||
if IS_FLASHINFER_AVAILABLE:
|
||||
|
||||
@@ -140,6 +140,7 @@ class AttentionMetadata:
|
||||
|
||||
_saved_tensors: Dict[str, torch.Tensor] = field(init=False,
|
||||
default_factory=dict)
|
||||
sparse_attention_config: Optional["SparseAttentionConfig"] = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.is_cross:
|
||||
@@ -563,6 +564,7 @@ class AttentionBackend(Generic[TMetadata]):
|
||||
num_kv_heads: Optional[int] = None,
|
||||
quant_config: Optional[QuantConfig] = None,
|
||||
skip_create_weights_in_init: bool = False,
|
||||
sparse_attention_config: Optional["SparseAttentionConfig"] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@@ -573,12 +575,14 @@ class AttentionBackend(Generic[TMetadata]):
|
||||
head_dim (int): The size of each attention head (hidden_size // num_heads).
|
||||
num_kv_heads (int): The number of kv heads. Defaults to num_heads if None.
|
||||
quant_config (QuantConfig): Optional quantization configuration. If None, no quantization is applied.
|
||||
sparse_attention_config (SparseAttentionConfig): Optional sparse attention configuration. If None, no sparse attention is applied.
|
||||
"""
|
||||
self.layer_idx = layer_idx
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = head_dim
|
||||
self.num_kv_heads = num_kv_heads or self.num_heads
|
||||
self.quant_config = quant_config
|
||||
self.sparse_attention_config = sparse_attention_config
|
||||
|
||||
def update_quant_config(self, new_quant_config: Optional[QuantConfig]):
|
||||
"""
|
||||
|
||||
11
tensorrt_llm/_torch/attention_backend/sparse/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from .utils import (get_flashinfer_sparse_attn_attention_backend,
|
||||
get_sparse_attn_kv_cache_manager,
|
||||
get_trtllm_sparse_attn_attention_backend,
|
||||
get_vanilla_sparse_attn_attention_backend)
|
||||
|
||||
__all__ = [
|
||||
"get_sparse_attn_kv_cache_manager",
|
||||
"get_vanilla_sparse_attn_attention_backend",
|
||||
"get_trtllm_sparse_attn_attention_backend",
|
||||
"get_flashinfer_sparse_attn_attention_backend",
|
||||
]
|
||||
308
tensorrt_llm/_torch/attention_backend/sparse/kernel.py
Normal file
@@ -0,0 +1,308 @@
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _index_gather_kernel(output_ptr, input_ptr, index_ptr, in_row_stride,
|
||||
in_token_stride, in_head_stride, idx_row_stride,
|
||||
idx_token_stride, idx_head_stride, dim_size,
|
||||
BLOCK_SIZE: tl.constexpr):
|
||||
# get program id and block offset
|
||||
row_pid = tl.program_id(0)
|
||||
token_pid = tl.program_id(1)
|
||||
head_pid = tl.program_id(2)
|
||||
token_block_num = tl.num_programs(1)
|
||||
head_num = tl.num_programs(2)
|
||||
|
||||
# get index
|
||||
indices_idx = row_pid * idx_row_stride + token_pid * idx_token_stride + head_pid * idx_head_stride
|
||||
token_idx = tl.load(index_ptr + indices_idx)
|
||||
|
||||
# get input and output base address
|
||||
input_base = (row_pid * in_row_stride + token_idx * in_token_stride +
|
||||
head_pid * in_head_stride)
|
||||
output_base = (row_pid * token_block_num * head_num * dim_size +
|
||||
token_pid * head_num * dim_size + head_pid * dim_size)
|
||||
|
||||
# process elements in blocks
|
||||
for dim_offset in tl.range(0, dim_size, BLOCK_SIZE):
|
||||
# get offsets
|
||||
offsets = tl.arange(0, BLOCK_SIZE)
|
||||
dim_indices = dim_offset + offsets
|
||||
mask = dim_indices < dim_size
|
||||
|
||||
# load input and store output
|
||||
input_val = tl.load(input_ptr + input_base + dim_indices,
|
||||
mask=mask,
|
||||
other=0.0)
|
||||
tl.store(output_ptr + output_base + dim_indices, input_val, mask=mask)
|
||||
|
||||
|
||||
def triton_index_gather(input, indices):
|
||||
assert input.ndim == 4, "Input must be a 4D tensor, [row, token, head, dim]"
|
||||
assert indices.ndim == 3, "Indices must be a 3D tensor, [row, token, head]"
|
||||
|
||||
# shape of input and indices
|
||||
row_size = input.shape[0]
|
||||
head_num = input.shape[2]
|
||||
dim_size = input.shape[3]
|
||||
num_tokens = indices.shape[1]
|
||||
|
||||
# create output tensor
|
||||
output = torch.empty((row_size, num_tokens, head_num, dim_size),
|
||||
device='cuda',
|
||||
dtype=input.dtype)
|
||||
|
||||
# launch kernel
|
||||
grid = (row_size, num_tokens, head_num)
|
||||
_index_gather_kernel[grid](output,
|
||||
input,
|
||||
indices,
|
||||
input.stride(0),
|
||||
input.stride(1),
|
||||
input.stride(2),
|
||||
indices.stride(0),
|
||||
indices.stride(1),
|
||||
indices.stride(2),
|
||||
dim_size,
|
||||
BLOCK_SIZE=1024)
|
||||
return output
|
||||
|
||||
|
||||
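triton_index_gather packs, for every (row, head) pair, the tokens selected by indices into a dense output tensor. A minimal shape sketch of the expected call (hypothetical sizes; assumes a CUDA device and Triton available):

import torch

from tensorrt_llm._torch.attention_backend.sparse.kernel import triton_index_gather

x = torch.randn(2, 128, 8, 64, device='cuda', dtype=torch.float16)  # [row, token, head, dim]
idx = torch.randint(0, 128, (2, 16, 8), device='cuda')  # [row, selected_token, head]
out = triton_index_gather(x, idx)
assert out.shape == (2, 16, 8, 64)  # 16 gathered tokens per row; head/dim layout preserved
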
@triton.jit
|
||||
def _update_kt_cache_ctx_kernel(k_ptr, cache_ptr, block_offsets_ptr,
|
||||
cum_seq_lens_ptr, cum_kt_seq_lens_ptr,
|
||||
token_to_batch_map_ptr, num_kv_heads, dim_size,
|
||||
kt_page_size, tokens_per_block,
|
||||
max_kt_blocks_per_seq,
|
||||
BLOCK_SIZE: tl.constexpr):
|
||||
# get program id
|
||||
kt_token_idx = tl.program_id(0)
|
||||
|
||||
# get params
|
||||
batch_idx = tl.load(token_to_batch_map_ptr + kt_token_idx)
|
||||
kv_start_idx = tl.load(cum_seq_lens_ptr + batch_idx)
|
||||
kv_end_idx = tl.load(cum_seq_lens_ptr + batch_idx + 1)
|
||||
kt_start_idx = tl.load(cum_kt_seq_lens_ptr + batch_idx)
|
||||
local_kt_token_idx = kt_token_idx - kt_start_idx
|
||||
global_kv_token_idx = kv_start_idx + local_kt_token_idx * kt_page_size
|
||||
|
||||
# get offsets
|
||||
hidden_size = num_kv_heads * dim_size
|
||||
k_base = k_ptr + global_kv_token_idx * hidden_size
|
||||
block_offset = batch_idx * max_kt_blocks_per_seq + local_kt_token_idx // tokens_per_block
|
||||
block_idx = tl.load(block_offsets_ptr + block_offset)
|
||||
token_idx_in_block = local_kt_token_idx % tokens_per_block
|
||||
cache_base = cache_ptr + (block_idx * tokens_per_block +
|
||||
token_idx_in_block) * hidden_size * 2
|
||||
|
||||
# compute min/max and store kt
|
||||
for hidden_start in tl.range(0, hidden_size, BLOCK_SIZE):
|
||||
hidden_indices = hidden_start + tl.arange(0, BLOCK_SIZE)
|
||||
head_idx = hidden_indices // dim_size
|
||||
dim_idx = hidden_indices % dim_size
|
||||
dim_mask = hidden_indices < hidden_size
|
||||
|
||||
# get k_min and k_max
|
||||
k_min = tl.full((BLOCK_SIZE, ), float('inf'), dtype=tl.float32)
|
||||
k_max = tl.full((BLOCK_SIZE, ), float('-inf'), dtype=tl.float32)
|
||||
for i in range(kt_page_size):
|
||||
if global_kv_token_idx + i < kv_end_idx:
|
||||
k = tl.load(k_base + i * hidden_size + hidden_indices,
|
||||
mask=dim_mask,
|
||||
other=0.0)
|
||||
k_min = tl.minimum(k_min, k)
|
||||
k_max = tl.maximum(k_max, k)
|
||||
k_min = k_min.to(cache_ptr.dtype.element_ty)
|
||||
k_max = k_max.to(cache_ptr.dtype.element_ty)
|
||||
|
||||
# store k_min and k_max to cache
|
||||
k_min_offset = cache_base + head_idx * dim_size * 2 + dim_idx
|
||||
k_max_offset = k_min_offset + dim_size
|
||||
tl.store(k_min_offset, k_min, mask=dim_mask)
|
||||
tl.store(k_max_offset, k_max, mask=dim_mask)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _update_kt_cache_gen_kernel(k_ptr, cache_ptr, block_offsets_ptr,
|
||||
seq_lens_ptr, num_kv_heads, dim_size,
|
||||
kt_page_size, tokens_per_block,
|
||||
max_kt_blocks_per_seq,
|
||||
BLOCK_SIZE: tl.constexpr):
|
||||
# get program id
|
||||
batch_idx = tl.program_id(0)
|
||||
head_idx = tl.program_id(1)
|
||||
|
||||
# get params
|
||||
past_key_value_length = tl.load(seq_lens_ptr + batch_idx) - 1
|
||||
kt_token_idx = past_key_value_length // kt_page_size
|
||||
kt_token_idx_in_page = past_key_value_length % kt_page_size
|
||||
|
||||
# get offsets
|
||||
hidden_size = num_kv_heads * dim_size
|
||||
k_base = k_ptr + batch_idx * hidden_size + head_idx * dim_size
|
||||
block_offset = batch_idx * max_kt_blocks_per_seq + kt_token_idx // tokens_per_block
|
||||
block_idx = tl.load(block_offsets_ptr + block_offset)
|
||||
kt_token_idx_in_block = kt_token_idx % tokens_per_block
|
||||
cache_base = cache_ptr + (block_idx * tokens_per_block +
|
||||
kt_token_idx_in_block) * hidden_size * 2
|
||||
cache_base += head_idx * dim_size * 2
|
||||
|
||||
# update kt cache
|
||||
for hidden_start in tl.range(0, dim_size, BLOCK_SIZE):
|
||||
hidden_indices = hidden_start + tl.arange(0, BLOCK_SIZE)
|
||||
dim_mask = hidden_indices < dim_size
|
||||
|
||||
# load k
|
||||
k = tl.load(k_base + hidden_indices, mask=dim_mask, other=0.0)
|
||||
|
||||
# load kt cache
|
||||
kt_mask = dim_mask & (kt_token_idx_in_page > 0)
|
||||
k_min = tl.load(cache_base + hidden_indices,
|
||||
mask=kt_mask,
|
||||
other=float('inf'))
|
||||
k_max = tl.load(cache_base + hidden_indices + dim_size,
|
||||
mask=kt_mask,
|
||||
other=float('-inf'))
|
||||
k_min = tl.minimum(k_min, k)
|
||||
k_max = tl.maximum(k_max, k)
|
||||
k_min = k_min.to(cache_ptr.dtype.element_ty)
|
||||
k_max = k_max.to(cache_ptr.dtype.element_ty)
|
||||
|
||||
# store k_min and k_max to cache
|
||||
tl.store(cache_base + hidden_indices, k_min, mask=dim_mask)
|
||||
tl.store(cache_base + hidden_indices + dim_size, k_max, mask=dim_mask)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _load_kt_cache_kernel(kt_states_ptr, cache_ptr, block_offsets_ptr,
|
||||
cum_kt_seq_lens_ptr, token_to_batch_map_ptr,
|
||||
num_kv_heads, dim_size, tokens_per_block,
|
||||
max_kt_blocks_per_seq, BLOCK_SIZE: tl.constexpr):
|
||||
# get program id
|
||||
kt_token_idx = tl.program_id(0)
|
||||
|
||||
# get params
|
||||
batch_idx = tl.load(token_to_batch_map_ptr + kt_token_idx)
|
||||
kt_start_idx = tl.load(cum_kt_seq_lens_ptr + batch_idx)
|
||||
local_kt_token_idx = kt_token_idx - kt_start_idx
|
||||
|
||||
# get offsets
|
||||
hidden_size = num_kv_heads * dim_size * 2
|
||||
kt_states_base = kt_states_ptr + kt_token_idx * hidden_size
|
||||
block_offset = batch_idx * max_kt_blocks_per_seq + local_kt_token_idx // tokens_per_block
|
||||
block_idx = tl.load(block_offsets_ptr + block_offset)
|
||||
token_idx_in_block = local_kt_token_idx % tokens_per_block
|
||||
cache_base = cache_ptr + (block_idx * tokens_per_block +
|
||||
token_idx_in_block) * hidden_size
|
||||
|
||||
# load kt cache
|
||||
for hidden_start in tl.range(0, hidden_size, BLOCK_SIZE):
|
||||
hidden_indices = hidden_start + tl.arange(0, BLOCK_SIZE)
|
||||
mask = hidden_indices < hidden_size
|
||||
# load kt cache
|
||||
kt = tl.load(cache_base + hidden_indices, mask=mask, other=0.0)
|
||||
# store kt to kt_states
|
||||
tl.store(kt_states_base + hidden_indices, kt, mask=mask)
|
||||
|
||||
|
||||
def triton_update_kt_cache(k,
|
||||
kt_cache_tensor,
|
||||
kt_cache_block_offsets,
|
||||
seq_lens,
|
||||
kt_page_size,
|
||||
tokens_per_block,
|
||||
max_kt_blocks_per_seq,
|
||||
update=True):
|
||||
# inputs:
|
||||
# k: (total_seq_len, num_kv_heads, head_dim)
|
||||
# kt_cache_tensor: (num_blocks, tokens_per_block, num_kv_heads, 2 * head_dim)
|
||||
# kt_cache_block_offsets: (max_batch_size, max_kt_blocks_per_seq)
|
||||
# seq_lens: (batch_size)
|
||||
# kt_page_size: int
|
||||
# update: bool
|
||||
|
||||
# outputs:
|
||||
# kt_states: (total_kt_tokens, num_kv_heads, 2 * head_dim)
|
||||
|
||||
# params
|
||||
batch_size = seq_lens.size(0)
|
||||
num_kv_heads = k.size(1)
|
||||
head_dim = k.size(2)
|
||||
tokens_per_block = kt_cache_tensor.size(1)
|
||||
num_kt_tokens = (seq_lens + kt_page_size - 1) // kt_page_size
|
||||
|
||||
# context
|
||||
if not update:
|
||||
total_num_kt_tokens = num_kt_tokens.sum().item()
|
||||
cum_seq_lens = torch.cumsum(torch.cat([
|
||||
torch.zeros(1, device='cuda', dtype=torch.long),
|
||||
seq_lens.to(torch.long)
|
||||
]),
|
||||
dim=0)
|
||||
cum_kt_seq_lens = torch.cumsum(torch.cat([
|
||||
torch.zeros(1, device='cuda', dtype=torch.long),
|
||||
num_kt_tokens.to(torch.long)
|
||||
]),
|
||||
dim=0)
|
||||
|
||||
token_to_batch_map = torch.repeat_interleave(
|
||||
torch.arange(batch_size,
|
||||
device='cuda'), repeats=num_kt_tokens).to(torch.long)
|
||||
grid = (total_num_kt_tokens, )
|
||||
_update_kt_cache_ctx_kernel[grid](k,
|
||||
kt_cache_tensor,
|
||||
kt_cache_block_offsets,
|
||||
cum_seq_lens,
|
||||
cum_kt_seq_lens,
|
||||
token_to_batch_map,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
kt_page_size,
|
||||
tokens_per_block,
|
||||
max_kt_blocks_per_seq,
|
||||
BLOCK_SIZE=1024)
|
||||
return
|
||||
else:
|
||||
# generation
|
||||
# update kt cache
|
||||
grid = (batch_size, num_kv_heads)
|
||||
_update_kt_cache_gen_kernel[grid](k,
|
||||
kt_cache_tensor,
|
||||
kt_cache_block_offsets,
|
||||
seq_lens,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
kt_page_size,
|
||||
tokens_per_block,
|
||||
max_kt_blocks_per_seq,
|
||||
BLOCK_SIZE=1024)
|
||||
|
||||
# load kt cache
|
||||
total_num_kt_tokens = num_kt_tokens.sum().item()
|
||||
kt_states = torch.empty(
|
||||
(total_num_kt_tokens, num_kv_heads, 2 * head_dim),
|
||||
device='cuda',
|
||||
dtype=k.dtype)
|
||||
token_to_batch_map = torch.repeat_interleave(
|
||||
torch.arange(batch_size,
|
||||
device='cuda'), repeats=num_kt_tokens).to(torch.long)
|
||||
cum_kt_seq_lens = torch.cumsum(torch.cat([
|
||||
torch.zeros(1, device='cuda', dtype=torch.long),
|
||||
num_kt_tokens.to(torch.long)
|
||||
]),
|
||||
dim=0)
|
||||
grid = (total_num_kt_tokens, )
|
||||
_load_kt_cache_kernel[grid](kt_states,
|
||||
kt_cache_tensor,
|
||||
kt_cache_block_offsets,
|
||||
cum_kt_seq_lens,
|
||||
token_to_batch_map,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
tokens_per_block,
|
||||
max_kt_blocks_per_seq,
|
||||
BLOCK_SIZE=1024)
|
||||
|
||||
return kt_states
|
||||
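In the context phase, triton_update_kt_cache reduces each kt_page_size-token chunk of K to an elementwise (min, max) pair per KV head and writes it into the kt cache; the generation-phase kernel then keeps those bounds current one token at a time. A pure-PyTorch reference of the context reduction for a single sequence (illustrative only; it ignores the paged cache layout and block offsets):

import torch

def kt_minmax_reference(k: torch.Tensor, kt_page_size: int) -> torch.Tensor:
    """Per-page elementwise min/max of K for one sequence.

    k: (seq_len, num_kv_heads, head_dim).
    Returns (num_pages, num_kv_heads, 2 * head_dim) with [min | max] along the last
    dim, matching the layout written by _update_kt_cache_ctx_kernel.
    """
    seq_len, num_kv_heads, head_dim = k.shape
    num_pages = (seq_len + kt_page_size - 1) // kt_page_size
    pad = num_pages * kt_page_size - seq_len
    k_pad = torch.cat([k, k.new_full((pad, num_kv_heads, head_dim), float('nan'))])
    pages = k_pad.view(num_pages, kt_page_size, num_kv_heads, head_dim)
    k_min = pages.nan_to_num(float('inf')).amin(dim=1)    # padding never wins the min
    k_max = pages.nan_to_num(float('-inf')).amax(dim=1)   # padding never wins the max
    return torch.cat([k_min, k_max], dim=-1)
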
1061
tensorrt_llm/_torch/attention_backend/sparse/rocket.py
Normal file
File diff suppressed because it is too large
39
tensorrt_llm/_torch/attention_backend/sparse/utils.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from .rocket import (RocketKVCacheManager, RocketTrtllmAttention,
|
||||
RocketVanillaAttention)
|
||||
|
||||
|
||||
def get_sparse_attn_kv_cache_manager(
|
||||
sparse_attn_config: "SparseAttentionConfig"):
|
||||
if sparse_attn_config.algorithm == "rocket":
|
||||
return RocketKVCacheManager
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported sparse attention algorithm: {sparse_attn_config.algorithm}"
|
||||
)
|
||||
|
||||
|
||||
def get_vanilla_sparse_attn_attention_backend(
|
||||
sparse_attn_config: "SparseAttentionConfig"):
|
||||
if sparse_attn_config.algorithm == "rocket":
|
||||
return RocketVanillaAttention
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported sparse attention algorithm in vanilla attention backend: {sparse_attn_config.algorithm}"
|
||||
)
|
||||
|
||||
|
||||
def get_trtllm_sparse_attn_attention_backend(
|
||||
sparse_attn_config: "SparseAttentionConfig"):
|
||||
if sparse_attn_config.algorithm == "rocket":
|
||||
return RocketTrtllmAttention
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported sparse attention algorithm in trtllm attention backend: {sparse_attn_config.algorithm}"
|
||||
)
|
||||
|
||||
|
||||
def get_flashinfer_sparse_attn_attention_backend(
|
||||
sparse_attn_config: "SparseAttentionConfig"):
|
||||
raise ValueError(
|
||||
f"Unsupported sparse attention algorithm in flashinfer attention backend: {sparse_attn_config.algorithm}"
|
||||
)
|
||||
@@ -190,6 +190,10 @@ class TrtllmAttentionWrapper:
|
||||
spec_decoding_generation_lengths: Optional[torch.Tensor] = None,
|
||||
attention_sinks: Optional[torch.Tensor] = None,
|
||||
chunked_prefill_buffer_batch_size: int = 1,
|
||||
sparse_kv_indices: Optional[torch.Tensor] = None,
|
||||
sparse_kv_offsets: Optional[torch.Tensor] = None,
|
||||
sparse_attn_indices: Optional[torch.Tensor] = None,
|
||||
sparse_attn_offsets: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@@ -229,6 +233,10 @@ class TrtllmAttentionWrapper:
|
||||
helix_position_offsets (torch.Tensor): The tensor to store the helix position offsets, with shape (num_tokens) on GPU.
|
||||
attention_sinks (torch.Tensor): The attention sinks (additional value in the denominator of the softmax) with shape of (num_heads_q) on GPU.
|
||||
            chunked_prefill_buffer_batch_size (int): Used to allocate the k/v buffer for fp8 context MLA; in that case the max input KV length is chunked_prefill_buffer_batch_size * max_num_tokens rather than max_num_tokens.
|
||||
sparse_kv_indices (torch.Tensor): The sparse indices for the KV cache, with shape of (num_heads_kv, num_sparse_tokens) on GPU.
|
||||
sparse_kv_offsets (torch.Tensor): The batch offsets for the sparse KV indices, with shape of (num_contexts + 1) on GPU.
|
||||
sparse_attn_indices (torch.Tensor): The sparse indices for the attention layer, with shape of (num_heads_kv, num_sparse_tokens) on GPU.
|
||||
sparse_attn_offsets (torch.Tensor): The batch offsets for the sparse attention indices, with shape of (num_generations + 1) on GPU.
|
||||
"""
|
||||
self.layer_idx = layer_idx
|
||||
self.tokens_per_block = tokens_per_block
|
||||
@@ -266,7 +274,10 @@ class TrtllmAttentionWrapper:
|
||||
self.softmax_stats_tensor = softmax_stats_tensor
|
||||
self.helix_position_offsets = helix_position_offsets
|
||||
self.attention_sinks = attention_sinks
|
||||
|
||||
self.sparse_kv_indices = sparse_kv_indices
|
||||
self.sparse_kv_offsets = sparse_kv_offsets
|
||||
self.sparse_attn_indices = sparse_attn_indices
|
||||
self.sparse_attn_offsets = sparse_attn_offsets
|
||||
if max_sequence_length > self.rope_params.max_positions:
|
||||
self.rope_params.max_positions = max_sequence_length
|
||||
self.rotary_inv_freq, self.rotary_cos_sin = self.rope_params.create_rope_const_params(
|
||||
@@ -424,6 +435,12 @@ class TrtllmAttentionWrapper:
|
||||
self.spec_decoding_position_offsets, self.spec_decoding_packed_mask
|
||||
]
|
||||
mla_tensor_params = [self.helix_position_offsets]
|
||||
sparse_attention_params = [
|
||||
self.sparse_kv_indices,
|
||||
self.sparse_kv_offsets,
|
||||
self.sparse_attn_indices,
|
||||
self.sparse_attn_offsets,
|
||||
]
|
||||
|
||||
thop.attention(
|
||||
q,
|
||||
@@ -491,6 +508,7 @@ class TrtllmAttentionWrapper:
|
||||
self.softmax_stats_tensor,
|
||||
spec_decoding_bool_params,
|
||||
spec_decoding_tensor_params,
|
||||
sparse_attention_params,
|
||||
)
|
||||
# reset the planned states (especially tensors) to avoid memory leak
|
||||
self.plan()
|
||||
@@ -1239,6 +1257,14 @@ class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]):
|
||||
use_paged_context_fmha=use_paged_context_fmha,
|
||||
is_mla_enable=self.is_mla_enable,
|
||||
)
|
||||
|
||||
sparse_kv_indices, sparse_kv_offsets, sparse_attn_indices, sparse_attn_offsets = None, None, None, None
|
||||
if self.sparse_attention_config is not None:
|
||||
sparse_kv_indices, sparse_kv_offsets = self.sparse_kv_predict(
|
||||
q, k, metadata)
|
||||
sparse_attn_indices, sparse_attn_offsets = self.sparse_attn_predict(
|
||||
q, k, metadata)
|
||||
|
||||
self.wrapper.plan(
|
||||
layer_idx=self.get_local_layer_idx(metadata),
|
||||
tokens_per_block=metadata.tokens_per_block,
|
||||
@ -1287,6 +1313,10 @@ class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]):
|
||||
spec_decoding_generation_lengths,
|
||||
attention_sinks=attention_sinks,
|
||||
chunked_prefill_buffer_batch_size=chunked_prefill_buffer_batch_size,
|
||||
sparse_kv_indices=sparse_kv_indices,
|
||||
sparse_kv_offsets=sparse_kv_offsets,
|
||||
sparse_attn_indices=sparse_attn_indices,
|
||||
sparse_attn_offsets=sparse_attn_offsets,
|
||||
)
|
||||
out_dtype = None
|
||||
if out_scale is not None:
|
||||
@ -1500,3 +1530,27 @@ class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]):
|
||||
self.num_heads,
|
||||
self.mla_params.v_head_dim,
|
||||
)
|
||||
|
||||
def sparse_attn_predict(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
k: Optional[torch.Tensor],
|
||||
metadata: TrtllmAttentionMetadata,
|
||||
**kwargs,
|
||||
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
"""
|
||||
Predict sparse attn indices. It's implemented in the derived class.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def sparse_kv_predict(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
k: Optional[torch.Tensor],
|
||||
metadata: TrtllmAttentionMetadata,
|
||||
**kwargs,
|
||||
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
"""
|
||||
Predict sparse kv indices. It's implemented in the derived class.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
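For context, a concrete backend is expected to override these two hooks and return (indices, offsets) pairs with the shapes documented in plan() above. The sketch below is illustrative only: the class name, the fixed budget, and the attributes it reads (metadata.seq_lens, self.num_kv_heads) are assumptions, not the shipped RocketKV implementation.

```python
# Hypothetical subclass sketching the contract of the sparse prediction hooks.
# Shapes follow the plan() docstring: indices are (num_heads_kv, num_sparse_tokens),
# offsets are prefix sums with (num_contexts + 1) / (num_generations + 1) entries.
from typing import Optional, Tuple

import torch

from tensorrt_llm._torch.attention_backend.trtllm import TrtllmAttention  # import path assumed


class ToySparseTrtllmAttention(TrtllmAttention):  # name is made up for illustration

    def sparse_kv_predict(
        self, q: torch.Tensor, k: Optional[torch.Tensor], metadata, **kwargs
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
        budget = 128                                    # made-up per-request token budget
        seq_lens = [int(s) for s in metadata.seq_lens]  # assumed metadata field
        kept = [min(budget, s) for s in seq_lens]
        device = q.device
        # (num_heads_kv, num_sparse_tokens): keep the first kept[i] tokens per request
        indices = torch.cat(
            [torch.arange(n, dtype=torch.int32, device=device) for n in kept])
        indices = indices.unsqueeze(0).repeat(self.num_kv_heads, 1)
        # (num_contexts + 1) offsets delimiting each request's slice of `indices`
        offsets = torch.zeros(len(kept) + 1, dtype=torch.int32, device=device)
        offsets[1:] = torch.cumsum(
            torch.tensor(kept, dtype=torch.int32, device=device), dim=0)
        return indices, offsets

    def sparse_attn_predict(self, q, k, metadata, **kwargs):
        # Same contract for the generation-phase indices; returning None is
        # assumed to mean "no sparsification" for this call.
        return None, None
```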
@ -3,22 +3,33 @@ from typing import Optional, Type
from ...models.modeling_utils import QuantConfig
from ..flashinfer_utils import IS_FLASHINFER_AVAILABLE
from .interface import AttentionBackend, MLAParams, PositionalEmbeddingParams
from .sparse import (get_flashinfer_sparse_attn_attention_backend,
get_trtllm_sparse_attn_attention_backend,
get_vanilla_sparse_attn_attention_backend)
from .trtllm import TrtllmAttention
from .vanilla import VanillaAttention


def get_attention_backend(backend_name: str) -> Type[AttentionBackend]:
def get_attention_backend(
backend_name: str,
sparse_attn_config: Optional["SparseAttentionConfig"] = None
) -> Type[AttentionBackend]:
if backend_name == "VANILLA":
if sparse_attn_config is not None:
return get_vanilla_sparse_attn_attention_backend(sparse_attn_config)
return VanillaAttention
elif backend_name == "TRTLLM":
if sparse_attn_config is not None:
return get_trtllm_sparse_attn_attention_backend(sparse_attn_config)
return TrtllmAttention
elif backend_name == "FLASHINFER" and IS_FLASHINFER_AVAILABLE:
from .flashinfer import FlashInferAttention

if sparse_attn_config is not None:
return get_flashinfer_sparse_attn_attention_backend(
sparse_attn_config)
return FlashInferAttention
elif backend_name == "FLASHINFER_STAR_ATTENTION" and IS_FLASHINFER_AVAILABLE:
from .star_flashinfer import StarAttention

return StarAttention

return TrtllmAttention
@ -42,12 +53,13 @@ def create_attention(
predicted_tokens_per_seq: Optional[int] = 1,
skip_create_weights_in_init: bool = False,
attention_chunk_size: Optional[int] = None,
sparse_attention_config: Optional["SparseAttentionConfig"] = None,
):
if attention_chunk_size is not None and backend_name.upper() != "TRTLLM":
raise ValueError(
f"Backend {backend_name} does not support chunked attention.")

attn_cls = get_attention_backend(backend_name)
attn_cls = get_attention_backend(backend_name, sparse_attention_config)

if is_mla_enable:
assert attn_cls.support_mla(
@ -76,4 +88,5 @@ def create_attention(
mla_params=mla_params,
skip_create_weights_in_init=skip_create_weights_in_init,
attention_chunk_size=attention_chunk_size,
sparse_attention_config=sparse_attention_config,
)
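A quick usage sketch of the dispatch above. The module path for get_attention_backend is assumed from this diff, and the concrete class returned on the sparse path is whatever get_trtllm_sparse_attn_attention_backend provides.

```python
# Resolve an attention backend class with and without a sparse attention config.
from tensorrt_llm._torch.attention_backend import get_attention_backend  # path assumed
from tensorrt_llm.llmapi import RocketSparseAttentionConfig

dense_cls = get_attention_backend("TRTLLM")               # plain TrtllmAttention
rocket_cfg = RocketSparseAttentionConfig(window_size=32, kernel_size=63)
sparse_cls = get_attention_backend("TRTLLM", rocket_cfg)  # RocketKV-enabled TRTLLM backend
print(dense_cls.__name__, sparse_cls.__name__)
```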
@ -13,6 +13,7 @@ except ImportError:

from .interface import (AttentionBackend, AttentionMask, AttentionMetadata,
PredefinedAttentionMask)
from .sparse.kernel import triton_index_gather


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
@ -94,90 +95,157 @@ class VanillaAttention(AttentionBackend[VanillaAttentionMetadata]):
self.num_key_value_groups = self.num_heads // self.num_kv_heads
self.q_scaling = q_scaling

def _single_request_update_kv_cache(self, k, v, kv_cache_tensor, seq_len,
cache_idx, cache_position):
def _single_request_sparse_attn_predict(
self, q: torch.Tensor, k: Optional[torch.Tensor],
v: Optional[torch.Tensor], kv_cache_tensor: torch.Tensor,
metadata: AttentionMetadata, past_seen_token: int, sample_idx: int,
**kwargs) -> tuple[Optional[torch.Tensor], int]:
raise NotImplementedError

def _single_request_sparse_kv_predict(
self, q: Optional[torch.Tensor], k: Optional[torch.Tensor],
v: Optional[torch.Tensor], metadata: AttentionMetadata,
past_seen_token: int, sample_idx: int,
**kwargs) -> tuple[Optional[torch.Tensor], int]:
raise NotImplementedError
def _single_request_update_kv_cache(self,
k,
v,
kv_cache_tensor,
past_seen_token,
kv_len,
cache_idx,
sparse_kv_indices=None):
# select tokens using the sparse kv indices
if sparse_kv_indices is not None:
k_selected = triton_index_gather(k, sparse_kv_indices)
v_selected = triton_index_gather(v, sparse_kv_indices)
else:
k_selected, v_selected = k, v

# get cache position
seq_len = past_seen_token + kv_len
cache_position = torch.arange(past_seen_token,
seq_len,
device=kv_cache_tensor.device)

# get kv cache tensor
k_out = kv_cache_tensor[cache_idx, 0, :, :, :].unsqueeze(0)
v_out = kv_cache_tensor[cache_idx, 1, :, :, :].unsqueeze(0)

# update kv cache
if k is not None and v is not None:
access_type = self._access_type[k.dtype.itemsize]
k_out.view(dtype=access_type).index_copy_(1, cache_position,
k.view(dtype=access_type))
v_out.view(dtype=access_type).index_copy_(1, cache_position,
v.view(dtype=access_type))
access_type = self._access_type[k_selected.dtype.itemsize]
k_out.view(dtype=access_type).index_copy_(
1, cache_position, k_selected.view(dtype=access_type))
v_out.view(dtype=access_type).index_copy_(
1, cache_position, v_selected.view(dtype=access_type))

return k_out[:, :seq_len, :, :], v_out[:, :seq_len, :, :]

def _single_request_forward(self,
q,
k,
v,
attention_mask: AttentionMask,
kv_cache_tensor,
past_seen_token,
cache_idx,
attention_window_size: Optional[int] = None):
# return past kv and the dense kv tensors for sparse attention
if sparse_kv_indices is not None:
k_states = torch.cat([k_out[:, :past_seen_token, :, :], k], dim=1)
v_states = torch.cat([v_out[:, :past_seen_token, :, :], v], dim=1)
else:
k_states, v_states = k_out[:, :seq_len, :, :], v_out[:, :
seq_len, :, :]
return k_states, v_states
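The Triton gather above is what narrows K and V down to the predicted tokens before the cache copy. A plain-PyTorch reference of the assumed semantics, with k/v laid out as (1, seq_len, num_kv_heads, head_dim) and indices as (num_kv_heads, num_sparse_tokens), is sketched below; the real triton_index_gather kernel may differ in layout or dtype handling.

```python
# Reference-only gather: out[b, i, h, d] = x[b, indices[h, i], h, d]
import torch


def index_gather_reference(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    """x: (1, seq_len, num_kv_heads, head_dim); indices: (num_kv_heads, n_keep)."""
    bsz, _, num_kv_heads, head_dim = x.shape
    n_keep = indices.shape[1]
    # Rearrange per-head token indices so torch.gather can pick along the token dim.
    idx = indices.transpose(0, 1).reshape(1, n_keep, num_kv_heads, 1)
    idx = idx.expand(bsz, n_keep, num_kv_heads, head_dim)
    return torch.gather(x, dim=1, index=idx.long())


k = torch.randn(1, 16, 2, 64)                        # 16 tokens, 2 KV heads
keep = torch.tensor([[0, 3, 7, 15], [1, 2, 8, 15]])  # 4 tokens kept per head
print(index_gather_reference(k, keep).shape)         # torch.Size([1, 4, 2, 64])
```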
def _single_request_preprocess_inputs(self, q, k, v, kv_dtype):
bsz = 1
q_len = q.size(0)

# Query
q = q.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

# Key and Value
target_seq_len = past_seen_token
kv_len = 0
if k is not None and v is not None:
kv_len = k.size(0)
k = k.view(bsz, kv_len, self.num_kv_heads, self.head_dim)
v = v.view(bsz, kv_len, self.num_kv_heads, self.head_dim)
target_seq_len += kv_len

if self.quant_config and self.quant_config.layer_quant_mode.has_any_quant(
):
qc = self.quant_config
if qc.layer_quant_mode.has_fp8_kv_cache():
assert kv_cache_tensor.dtype == torch.float8_e4m3fn, f"KV cache should have fp8 dtype, but got {kv_cache_tensor.dtype}"
assert kv_dtype == torch.float8_e4m3fn, \
f"KV cache should have fp8 dtype, but got {kv_dtype}"
k = k.to(torch.float8_e4m3fn)
v = v.to(torch.float8_e4m3fn)
assert k.dtype == v.dtype == kv_cache_tensor.dtype, f"KV cache dtype {kv_cache_tensor.dtype} does not match k/v dtype {k.dtype}/{v.dtype}"
assert k.dtype == v.dtype == kv_dtype, \
f"KV cache dtype {kv_dtype} does not match k/v dtype {k.dtype}/{v.dtype}"

cache_position = torch.arange(past_seen_token,
target_seq_len,
device=q.device)
return q, k, v, kv_len

key_states, value_states = self._single_request_update_kv_cache(
k, v, kv_cache_tensor, target_seq_len, cache_idx, cache_position)
def _single_request_create_attention_mask(self,
attention_mask,
past_seen_token,
kv_len,
q_device,
q_len,
attention_window_size=None):
"""
Create the appropriate attention mask based on the attention type.

key_states = key_states.transpose(1, 2).to(q.dtype)
value_states = value_states.transpose(1, 2).to(q.dtype)

key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)

# Attention Mask
Returns:
Tuple of (attn_mask, is_causal)
"""
bsz = 1
is_causal = False
attn_mask = None

# get cache position
seq_len = past_seen_token + kv_len
cache_position = torch.arange(past_seen_token, seq_len, device=q_device)

# create attention mask
if attention_mask == PredefinedAttentionMask.CAUSAL:
# Create custom sliding window mask as sdpa doesn't natively support it.
if attention_window_size is not None:
attn_mask = generate_sliding_window_mask(
bsz, target_seq_len, cache_position, q.device,
bsz, seq_len, cache_position, q_device,
attention_window_size)
elif past_seen_token == 0:
is_causal = True
elif q_len != 1:
# attn_mask: 4-D tensor (batch_size, 1, query_seq_len, seq_len)
attn_mask = generate_causal_mask(bsz, target_seq_len,
cache_position, q.device)
attn_mask = generate_causal_mask(bsz, seq_len, cache_position,
q_device)
elif attention_mask == PredefinedAttentionMask.FULL:
pass
else:
raise ValueError("Unexpected attention mask type")

return attn_mask, is_causal
def _single_request_attn_forward(self,
q,
key_states,
value_states,
is_causal,
attn_mask,
sparse_indices=None):
"""
Common attention computation using scaled dot-product attention.
"""
# select the key and value states using the sparse indices
if sparse_indices is not None:
key_states = triton_index_gather(key_states, sparse_indices)
value_states = triton_index_gather(value_states, sparse_indices)

# transpose kv
key_states = key_states.transpose(1, 2).to(q.dtype)
value_states = value_states.transpose(1, 2).to(q.dtype)

# repeat kv to support MQA/GQA
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)

# get qk scale
qk_scale = None
if self.q_scaling is not None:
qk_scale = 1 / (math.sqrt(self.head_dim) * self.q_scaling)

# attention
attn_output = torch.nn.functional.scaled_dot_product_attention(
q,
key_states,
@ -187,21 +255,66 @@ class VanillaAttention(AttentionBackend[VanillaAttentionMetadata]):
scale=qk_scale,
)

attn_output = attn_output.squeeze(0)
return attn_output

@staticmethod
def _single_request_forward(self,
q,
k,
v,
attention_mask: AttentionMask,
kv_cache_tensor,
past_seen_token,
cache_idx,
sample_idx,
metadata: AttentionMetadata,
attention_window_size: Optional[int] = None,
**kwargs):
# preprocess inputs
q, k, v, kv_len = self._single_request_preprocess_inputs(
q, k, v, kv_cache_tensor.dtype)

# predict sparse kv indices
sparse_kv_indices = None
if self.sparse_attention_config is not None:
sparse_kv_indices, kv_len = self._single_request_sparse_kv_predict(
q, k, v, metadata, past_seen_token, sample_idx)

# update kv cache
key_states, value_states = self._single_request_update_kv_cache(
k, v, kv_cache_tensor, past_seen_token, kv_len, cache_idx,
sparse_kv_indices)

# predict sparse attn indices
sparse_indices = None
if self.sparse_attention_config is not None:
sparse_indices, kv_len = self._single_request_sparse_attn_predict(
q, k, v, kv_cache_tensor, metadata, past_seen_token, sample_idx)

# create attention mask
attn_mask, is_causal = self._single_request_create_attention_mask(
attention_mask, past_seen_token, kv_len, q.device, q.size(2),
attention_window_size)

# attention
attn_output = self._single_request_attn_forward(q, key_states,
value_states, is_causal,
attn_mask,
sparse_indices)

return attn_output.squeeze(0)
def no_kv_cache_forward(
q: torch.Tensor,
k: Optional[torch.Tensor],
v: Optional[torch.Tensor],
num_heads: int,
num_kv_heads: int,
metadata: AttentionMetadata,
*,
attention_mask: AttentionMask = PredefinedAttentionMask.CAUSAL,
position_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
self,
q: torch.Tensor,
k: Optional[torch.Tensor],
v: Optional[torch.Tensor],
num_heads: int,
num_kv_heads: int,
metadata: AttentionMetadata,
*,
attention_mask: AttentionMask = PredefinedAttentionMask.CAUSAL,
position_ids: Optional[torch.Tensor] = None,
**kwargs) -> torch.Tensor:
"""
This function is used to perform attention without kv cache.
Args:
@ -275,14 +388,14 @@ class VanillaAttention(AttentionBackend[VanillaAttentionMetadata]):
# try to separate the kv cache estimation path from no kv cache attn.
num_heads = self.num_heads
num_kv_heads = self.num_kv_heads
return VanillaAttention.no_kv_cache_forward(
q=q,
k=k,
v=v,
num_heads=num_heads,
num_kv_heads=num_kv_heads,
metadata=metadata,
attention_mask=attention_mask)
return self.no_kv_cache_forward(q=q,
k=k,
v=v,
num_heads=num_heads,
num_kv_heads=num_kv_heads,
metadata=metadata,
attention_mask=attention_mask,
**kwargs)

past_seen_tokens = metadata.kv_cache_params.num_cached_tokens_per_seq
cache_indices = [
@ -298,7 +411,7 @@ class VanillaAttention(AttentionBackend[VanillaAttentionMetadata]):
offset = 0
offset_kv = 0
attn_outputs = []
for i, (seq_len, seq_len_kv) in enumerate(
for sample_idx, (seq_len, seq_len_kv) in enumerate(
zip(metadata.seq_lens, metadata.seq_lens_kv)):
single_q = q[offset:offset + seq_len]
single_k = k[
@ -306,13 +419,16 @@ class VanillaAttention(AttentionBackend[VanillaAttentionMetadata]):
seq_len_kv] if k is not None and seq_len_kv != 0 else None
single_v = v[
offset_kv:offset_kv +
seq_len_kv] if k is not None and seq_len_kv != 0 else None
past_seen_token = past_seen_tokens[i]
cache_idx = cache_indices[i]
seq_len_kv] if v is not None and seq_len_kv != 0 else None

past_seen_token = past_seen_tokens[sample_idx]
cache_idx = cache_indices[sample_idx]

attn_output = self._single_request_forward(
single_q, single_k, single_v, attention_mask, kv_cache_tensor,
past_seen_token, cache_idx, attention_window_size)
past_seen_token, cache_idx, sample_idx, metadata,
attention_window_size, **kwargs)

attn_outputs.append(attn_output)

offset += seq_len
@ -121,6 +121,7 @@ class ModelConfig(Generic[TConfig]):

spec_config: Optional["DecodingBaseConfig"] = None
lora_config: Optional["LoraConfig"] = None
sparse_attention_config: Optional["SparseAttentionConfig"] = None

is_generation: bool = True
max_num_tokens: int = 8192
@ -259,7 +259,9 @@ class Attention(nn.Module):

self.quant_config = config.get_quant_config()
self.attn_backend = config.attn_backend
attn_cls = get_attention_backend(self.attn_backend)
attn_cls = get_attention_backend(
self.attn_backend,
sparse_attn_config=config.sparse_attention_config)

# These two modules are mutually exclusive - either splitted_qkv_lora or fused_qkv_lora will be used,
# but never both at the same time. splitted_qkv_lora handles Q,K,V separately while fused_qkv_lora
@ -277,6 +279,9 @@ class Attention(nn.Module):
# Whether to fuse RoPE into the attention OP.
# If true, RoPE will be applied in self.attn.forward.
# If false, RoPE will be applied in self.apply_rope.
if config.sparse_attention_config is not None:
logger.warning("Disabling rope_fusion for sparse attention.")
rope_fusion = False
self.rope_fusion = rope_fusion
if self.rope_fusion and not attn_cls.support_fused_rope():
logger.warning(
@ -314,6 +319,7 @@ class Attention(nn.Module):
skip_create_weights_in_init=config.skip_create_weights_in_init,
q_scaling=self.q_scaling,
attention_chunk_size=self.attention_chunk_size,
sparse_attention_config=config.sparse_attention_config,
)

self.support_fused_qkv = self.attn.support_fused_qkv()
@ -854,6 +860,7 @@ class MLA(nn.Module):
v_head_dim=self.v_head_dim,
predicted_tokens_per_seq=self.predicted_tokens_per_seq,
skip_create_weights_in_init=config.skip_create_weights_in_init,
sparse_attention_config=config.sparse_attention_config,
)

self.mqa = create_attention(
@ -873,6 +880,7 @@ class MLA(nn.Module):
v_head_dim=self.kv_lora_rank,
predicted_tokens_per_seq=self.predicted_tokens_per_seq,
skip_create_weights_in_init=config.skip_create_weights_in_init,
sparse_attention_config=config.sparse_attention_config,
)

self.aux_stream = aux_stream
@ -1,6 +1,5 @@
import os
import random
from collections.abc import Iterable
from typing import Dict, List, Optional

import torch
@ -12,14 +11,15 @@ from tensorrt_llm._utils import str_dtype_to_binding, torch_dtype_to_str
from tensorrt_llm.bindings.executor import DecodingMode
from tensorrt_llm.llmapi.llm_args import (EagleDecodingConfig, KvCacheConfig,
MTPDecodingConfig, PeftCacheConfig,
SamplerType, SpeculativeConfig,
TorchLlmArgs)
SamplerType, SparseAttentionConfig,
SpeculativeConfig, TorchLlmArgs)
from tensorrt_llm.logger import logger
from tensorrt_llm.lora_helper import (LoraConfig,
get_default_trtllm_modules_to_hf_modules)
from tensorrt_llm.lora_manager import load_torch_lora
from tensorrt_llm.mapping import CpType, Mapping

from ..attention_backend import get_sparse_attn_kv_cache_manager
from ..model_config import ModelConfig
from ..speculative import get_num_extra_kv_tokens, get_spec_decoder
from .config import PyTorchConfig
@ -42,6 +42,20 @@ from .seq_slot_manager import SeqSlotManager
GB = 1 << 30


def get_kv_cache_manager_cls(model_config: ModelConfig):
config = model_config.pretrained_config
sparse_attn_config = model_config.sparse_attention_config
if is_mla(config):
return KVCacheManager
elif is_nemotron_hybrid(config):
return MambaHybridCacheManager
else:
if sparse_attn_config is not None:
return get_sparse_attn_kv_cache_manager(sparse_attn_config)
else:
return KVCacheManager
class KvCacheCreator:
"""Groups together logic related to KV cache construction."""

@ -61,6 +75,7 @@ class KvCacheCreator:
kv_cache_config: KvCacheConfig,
pytorch_backend_config: PyTorchConfig,
speculative_config: SpeculativeConfig,
sparse_attention_config: SparseAttentionConfig,
):
self._model_engine = model_engine
self._draft_model_engine = draft_model_engine
@ -72,50 +87,14 @@ class KvCacheCreator:
self._kv_connector_manager = kv_connector_manager
self._pytorch_backend_config = pytorch_backend_config
self._speculative_config = speculative_config
self._sparse_attention_config = sparse_attention_config
self._tokens_per_block = tokens_per_block
self._max_seq_len = max_seq_len
self._max_batch_size = max_batch_size
self._net_max_seq_len = net_max_seq_len
self._dummy_reqs = None

@staticmethod
def _get_cache_size_per_token(model_config: ModelConfig,
mapping: Mapping) -> int:
mem_per_token = 2
quant_config = model_config.quant_config
if quant_config is not None and quant_config.quant_mode.has_fp8_kv_cache(
):
mem_per_token = 1

config = model_config.pretrained_config

num_key_value_heads = getattr(config, 'num_key_value_heads',
config.num_attention_heads)
if isinstance(num_key_value_heads, Iterable):
num_key_value_heads = sum(num_key_value_heads) / len(
num_key_value_heads)

mla = is_mla(config)
tp_size = 1 if mapping.enable_attention_dp else mapping.tp_size

kv_factor = 2
if mla:
# MLA has kv_lora_rank and qk_rope_head_dim
head_dim = config.kv_lora_rank + config.qk_rope_head_dim
kv_factor = 1
else:
_head_dim = getattr(config, 'head_dim', None)
if not isinstance(_head_dim, int):
_head_dim = config.hidden_size // config.num_attention_heads
head_dim = _head_dim * num_key_value_heads // tp_size

# provide at least 1 layer to prevent division by zero cache size
num_attention_layers = max(
len(mapping.pp_layers(model_config.get_num_attention_layers())), 1)
mem_per_token *= num_attention_layers * head_dim
# K and V
mem_per_token *= kv_factor
return mem_per_token
self._kv_cache_manager_cls = get_kv_cache_manager_cls(
model_engine.model.model_config)

def _get_free_gpu_memory_fraction(self) -> float:
fraction = self._kv_cache_config.free_gpu_memory_fraction
@ -126,12 +105,14 @@ class KvCacheCreator:
def _get_kv_size_per_token(self):
model_config = self._model_engine.model.model_config
mapping = self._mapping
kv_size_per_token = self._get_cache_size_per_token(
model_config, mapping)
kv_size_per_token = self._kv_cache_manager_cls.get_cache_size_per_token(
model_config, mapping, tokens_per_block=self._tokens_per_block)
if self._draft_model_engine is not None:
draft_model_config = self._draft_model_engine.model.model_config
kv_size_per_token += self._get_cache_size_per_token(
draft_model_config, mapping)
kv_size_per_token += self._kv_cache_manager_cls.get_cache_size_per_token(
draft_model_config,
mapping,
tokens_per_block=self._tokens_per_block)
return kv_size_per_token

def _cal_max_memory(self, peak_memory, total_gpu_memory, fraction,
@ -231,6 +212,12 @@ class KvCacheCreator:
estimating_kv_cache = True
self._kv_cache_config.max_tokens = self._get_token_num_for_estimation(
)
model_config = self._model_engine.model.model_config
if model_config.attn_backend == "VANILLA":
logger.info(
"KV cache size estimation is not supported for the VANILLA attention backend; disabling it."
)
estimating_kv_cache = False
return estimating_kv_cache

def configure_kv_cache_capacity(self, py_executor: PyExecutor) -> None:
@ -374,6 +361,7 @@ class KvCacheCreator:
config = model_engine.model.model_config.pretrained_config
quant_config = model_engine.model.model_config.quant_config
spec_config = self._speculative_config
sparse_attn_config = self._sparse_attention_config

hidden_size = config.hidden_size
num_attention_heads = config.num_attention_heads
@ -396,7 +384,7 @@ class KvCacheCreator:
num_hidden_layers = config.num_hidden_layers

if is_mla(config):
kv_cache_manager = KVCacheManager(
kv_cache_manager = self._kv_cache_manager_cls(
self._kv_cache_config,
tensorrt_llm.bindings.internal.batch_manager.CacheType.
SELFKONLY,
@ -434,8 +422,7 @@ class KvCacheCreator:
mamba_layer_mask = [
char == "M" for char in config.hybrid_override_pattern
]

kv_cache_manager = MambaHybridCacheManager(
kv_cache_manager = self._kv_cache_manager_cls(
# mamba cache parameters
config.ssm_state_size,
config.conv_kernel,
@ -518,7 +505,7 @@ class KvCacheCreator:
binding_model_config = model_engine.model.model_config.get_bindings_model_config(
tokens_per_block=self._tokens_per_block) if is_vswa else None

kv_cache_manager = KVCacheManager(
kv_cache_manager = self._kv_cache_manager_cls(
self._kv_cache_config,
tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF,
num_layers=num_hidden_layers,
@ -536,6 +523,7 @@ class KvCacheCreator:
is_draft=model_engine.is_draft_model,
kv_connector_manager=self._kv_connector_manager
if not estimating_kv_cache else None,
sparse_attn_config=sparse_attn_config,
)
# The (non-draft) KVCacheManager modifies the max_seq_len field; propagate the update to self
if model_engine.kv_cache_manager_key == ResourceManagerType.KV_CACHE_MANAGER:
@ -137,6 +137,7 @@ class PyTorchModelEngine(ModelEngine):
attn_runtime_features: Optional[AttentionRuntimeFeatures] = None,
dist: Optional[MPIDist] = None,
spec_config: Optional["DecodingBaseConfig"] = None,
sparse_attention_config: Optional["SparseAttentionConfig"] = None,
lora_config: Optional[LoraConfig] = None,
is_draft_model: bool = False,
drafting_loop_wrapper: Optional[Callable[[torch.nn.Module],
@ -164,6 +165,7 @@ class PyTorchModelEngine(ModelEngine):
spec_config.max_draft_len = 0
self.spec_config = spec_config
self.is_spec_decode = spec_config is not None
self.sparse_attention_config = sparse_attention_config
self.enable_spec_decode = self.is_spec_decode
self.is_draft_model = is_draft_model

@ -175,6 +177,7 @@ class PyTorchModelEngine(ModelEngine):
pytorch_backend_config=pytorch_backend_config,
mapping=self.mapping,
spec_config=self.spec_config,
sparse_attention_config=self.sparse_attention_config,
max_num_tokens=max_num_tokens,
max_seq_len=max_seq_len,
lora_config=lora_config,
@ -261,7 +264,8 @@ class PyTorchModelEngine(ModelEngine):
self.is_warmup = False

self.attn_backend = get_attention_backend(
pytorch_backend_config.attn_backend)
pytorch_backend_config.attn_backend,
sparse_attn_config=sparse_attention_config)

if self.is_spec_decode:
self.spec_metadata = None
@ -794,7 +798,8 @@ class PyTorchModelEngine(ModelEngine):
enable_flash_mla=self.model.model_config.enable_flash_mla,
enable_context_mla_with_cached_kv=
enable_context_mla_with_cached_kv,
cache_indirection=cache_indirection)
cache_indirection=cache_indirection,
sparse_attention_config=self.sparse_attention_config)

if self.attn_metadata is not None:
# This assertion can be relaxed if needed: just create a new metadata
@ -811,7 +816,8 @@ class PyTorchModelEngine(ModelEngine):
runtime_features=self.attn_runtime_features,
enable_flash_mla=self.model.model_config.enable_flash_mla,
enable_context_mla_with_cached_kv=enable_context_mla_with_cached_kv,
cache_indirection=cache_indirection)
cache_indirection=cache_indirection,
sparse_attention_config=self.sparse_attention_config)

return self.attn_metadata
@ -160,6 +160,7 @@ class ModelLoader:
pytorch_backend_config: PyTorchConfig,
mapping: Mapping,
spec_config: Optional["DecodingBaseConfig"],
sparse_attention_config: Optional["SparseAttentionConfig"],
max_num_tokens: int,
max_seq_len: Optional[int],
lora_config: Optional[LoraConfig] = None):
@ -177,6 +178,7 @@ class ModelLoader:
self.pytorch_backend_config = pytorch_backend_config
self.mapping = mapping
self.spec_config = spec_config
self.sparse_attention_config = sparse_attention_config
self.max_num_tokens = max_num_tokens
self.max_seq_len = max_seq_len
self.lora_config = lora_config
@ -294,6 +296,7 @@ class ModelLoader:
force_dynamic_quantization=self.pytorch_backend_config.
force_dynamic_quantization,
spec_config=self.spec_config,
sparse_attention_config=self.sparse_attention_config,
max_num_tokens=self.max_num_tokens,
max_seq_len=self.max_seq_len,
moe_max_num_tokens=self.pytorch_backend_config.moe_max_num_tokens,
@ -252,6 +252,8 @@ def create_py_executor(
max_num_tokens = 8192

tokens_per_block = kv_cache_config.tokens_per_block
if pytorch_backend_config.attn_backend == "VANILLA":
tokens_per_block = max_num_tokens

if pytorch_backend_config.attn_backend in [
"FLASHINFER", "FLASHINFER_STAR_ATTENTION"
@ -308,6 +310,8 @@ def create_py_executor(
has_draft_model_engine = spec_config.spec_dec_mode.has_draft_model()
has_spec_drafter = spec_config.spec_dec_mode.has_spec_drafter()

sparse_attention_config = llm_args.sparse_attention_config

# chunk_unit_size may be changed to 64 when using flash mla
attn_runtime_features = AttentionRuntimeFeatures(
chunked_prefill=enable_chunked_context,
@ -331,6 +335,7 @@ def create_py_executor(
attn_runtime_features=attn_runtime_features,
dist=dist,
spec_config=spec_config,
sparse_attention_config=sparse_attention_config,
lora_config=lora_config,
checkpoint_loader=checkpoint_loader,
)
@ -568,6 +573,7 @@ def create_py_executor(
kv_cache_config=kv_cache_config,
pytorch_backend_config=pytorch_backend_config,
speculative_config=spec_config,
sparse_attention_config=sparse_attention_config,
)
estimating_kv_cache = kv_cache_creator.try_prepare_estimation()
with mem_monitor.observe_creation_stage(
@ -3,7 +3,7 @@ import enum
import math
from abc import ABC, abstractmethod
from collections import OrderedDict, defaultdict
from typing import Dict, List, Optional, Set, Tuple, Union
from typing import Dict, Iterable, List, Optional, Set, Tuple, Union

import torch

@ -165,6 +165,7 @@ class KVCacheManager(BaseResourceManager):
max_beam_width: int = 1,
is_draft: bool = False,
kv_connector_manager: Optional[KvCacheConnectorManager] = None,
**kwargs,
) -> None:
self.mapping = mapping
self.dtype = dtype
@ -543,6 +544,64 @@ class KVCacheManager(BaseResourceManager):
return get_size_in_bytes(cache_size // quant_vector_size,
scaling_factor_dtype)

# TODO: refactor get_cache_size_per_token and get_cache_bytes_per_token to use the same logic
@staticmethod
def get_cache_size_per_token(model_config: ModelConfigPython,
mapping: Mapping, **kwargs):
# get kv cache dtype bytes
mem_per_token = 2
quant_config = model_config.quant_config
if quant_config is not None and quant_config.quant_mode.has_fp8_kv_cache(
):
mem_per_token = 1

# get num key value heads
config = model_config.pretrained_config
num_key_value_heads = getattr(config, 'num_key_value_heads',
config.num_attention_heads)
if isinstance(num_key_value_heads, Iterable):
num_key_value_heads = sum(num_key_value_heads) / len(
num_key_value_heads)

# get head dim
mla = hasattr(config, "kv_lora_rank")
if mla:
head_dim = config.kv_lora_rank + config.qk_rope_head_dim
kv_factor = 1
else:
tp_size = 1 if mapping.enable_attention_dp else mapping.tp_size
head_dim = getattr(config, "head_dim", None)
if not isinstance(head_dim, int):
head_dim = config.hidden_size // config.num_attention_heads
head_dim = head_dim * num_key_value_heads // tp_size
kv_factor = 2

# provide at least 1 layer to prevent division by zero cache size
num_attention_layers = max(
len(mapping.pp_layers(model_config.get_num_attention_layers())), 1)
mem_per_token *= num_attention_layers * head_dim

# K and V
mem_per_token *= kv_factor
return mem_per_token

def get_cache_bytes_per_token(self):
cache_size_per_token = self.kv_factor * sum(
self.num_kv_heads_per_layer) * self.head_dim

if self.dtype not in (DataType.FP8, DataType.HALF, DataType.BF16,
DataType.FLOAT, DataType.NVFP4):
raise ValueError(f'Cannot support {self.dtype} KV cache.')

cache_size_bytes_per_token = get_size_in_bytes(cache_size_per_token,
self.dtype)
if self.dtype == DataType.NVFP4:
cache_size_bytes_per_token += self.calculate_scaling_factor_size_bytes(
cache_size_per_token,
quant_vector_size=16,
scaling_factor_dtype=DataType.FP8)
return cache_size_bytes_per_token

def calculate_max_num_blocks(self,
kv_cache_config: KvCacheConfig,
head_dim: int,
@ -554,20 +613,7 @@ class KVCacheManager(BaseResourceManager):
if kv_cache_config.free_gpu_memory_fraction
is not None else 0.9)

cache_size_per_token = kv_factor * sum(
self.num_kv_heads_per_layer) * head_dim

if dtype not in (DataType.FP8, DataType.HALF, DataType.BF16,
DataType.FLOAT, DataType.NVFP4):
raise ValueError(f'Cannot support {dtype} KV cache.')

cache_size_bytes_per_token = get_size_in_bytes(cache_size_per_token,
dtype)
if dtype == DataType.NVFP4:
cache_size_bytes_per_token += self.calculate_scaling_factor_size_bytes(
cache_size_per_token,
quant_vector_size=16,
scaling_factor_dtype=DataType.FP8)
cache_size_bytes_per_token = self.get_cache_bytes_per_token()

free_mem, total_mem = torch.cuda.mem_get_info()
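As a sanity check on get_cache_size_per_token, the arithmetic below plugs in a hypothetical dense Llama-style configuration (32 attention layers, 8 KV heads, head_dim 128, FP16 cache, TP=1, no attention DP). The numbers are illustrative and not tied to any particular checkpoint.

```python
# Worked example of the per-token KV cache size computed above.
bytes_per_elem = 2                        # FP16/BF16 cache (1 for an FP8 cache)
num_kv_heads = 8
tp_size = 1
head_dim = 128 * num_kv_heads // tp_size  # head_dim * num_key_value_heads / tp_size
num_attention_layers = 32
kv_factor = 2                             # K and V (1 for MLA)

mem_per_token = bytes_per_elem * num_attention_layers * head_dim * kv_factor
print(mem_per_token)                      # 131072 bytes, i.e. 128 KiB per token
```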
@ -12,6 +12,7 @@ from .llm_args import (AttentionDpConfig, AutoDecodingConfig, BatchingType,
ExtendedRuntimePerfKnobConfig, KvCacheConfig, LlmArgs,
LookaheadDecodingConfig, MedusaDecodingConfig, MoeConfig,
MTPDecodingConfig, NGramDecodingConfig,
RocketSparseAttentionConfig,
SaveHiddenStatesDecodingConfig, SchedulerConfig,
TorchCompileConfig, TorchLlmArgs, TrtLlmArgs,
UserProvidedDecodingConfig)
@ -61,4 +62,5 @@ __all__ = [
'AttentionDpConfig',
'LoRARequest',
'SaveHiddenStatesDecodingConfig',
'RocketSparseAttentionConfig',
]
@ -166,6 +166,63 @@ class CudaGraphConfig(StrictBaseModel):
return batch_sizes


class BaseSparseAttentionConfig(StrictBaseModel):
"""
Configuration for sparse attention.
"""
algorithm: Literal["rocket"] = Field(
default="rocket", description="The algorithm for sparse attention.")

@classmethod
def from_dict(cls, data: dict):
# dispatch to the correct sparse attention config
config_classes = {
"rocket": RocketSparseAttentionConfig,
}

algorithm = data.get("algorithm", None)
if algorithm is None:
raise ValueError("Sparse attention algorithm is required")

config_class = config_classes.get(algorithm.lower())
if config_class is None:
raise ValueError(f"Invalid algorithm: {algorithm}")

return config_class(**data)

def _check_fields(self):
pass

def supports_backend(self, backend: str) -> bool:
"""
Override if the sparse attention algorithm does not support
a subset of the possible backends.
"""
return True


class RocketSparseAttentionConfig(BaseSparseAttentionConfig):
"""
Configuration for rocket sparse attention.
"""
window_size: Optional[int] = Field(
default=None, description="The window size for snap KV.")
kernel_size: Optional[int] = Field(
default=None, description="The kernel size for snap KV.")
topr: Optional[Union[int, float]] = Field(default=76, description="Top-r")
topk: Optional[int] = Field(default=128, description="Top-k")
prompt_budget: Optional[int] = Field(default=1266,
description="Prompt budget")
page_size: Optional[int] = Field(default=3, description="Page size")

@classmethod
def from_dict(cls, data: dict):
return cls(**data)

def supports_backend(self, backend: str) -> bool:
return backend == "pytorch"


class MoeConfig(StrictBaseModel):
"""
Configuration for MoE.
@ -1133,6 +1190,10 @@ SpeculativeConfig: TypeAlias = Optional[Union[
AutoDecodingConfig,
]]

SparseAttentionConfig: TypeAlias = Union[
RocketSparseAttentionConfig,
]


@PybindMirror.mirror_pybind_fields(_KvCacheConfig)
class KvCacheConfig(StrictBaseModel, PybindMirror):
@ -1510,6 +1571,12 @@ class BaseLlmArgs(StrictBaseModel):
description="Cache transceiver config.",
status="prototype")

# Sparse attention config
sparse_attention_config: Optional[SparseAttentionConfig] = Field(
default=None,
description="Sparse attention config.",
status="prototype")

# Speculative decoding parameters
speculative_config: SpeculativeConfig = Field(
default=None, description="Speculative decoding config.")
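Before the new unit test below, a short sketch of how these config classes are meant to be used: either constructed directly, as the test does, or dispatched through BaseSparseAttentionConfig.from_dict. The field values here are arbitrary.

```python
# Build a RocketKV sparse attention config via the from_dict dispatcher.
from tensorrt_llm.llmapi.llm_args import (BaseSparseAttentionConfig,
                                          RocketSparseAttentionConfig)

cfg = BaseSparseAttentionConfig.from_dict({
    "algorithm": "rocket",   # any other value raises ValueError
    "window_size": 32,
    "kernel_size": 63,
    "prompt_budget": 2048,
})
assert isinstance(cfg, RocketSparseAttentionConfig)
assert cfg.supports_backend("pytorch") and not cfg.supports_backend("tensorrt")
```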
tests/unittest/_torch/attention/sparse/test_rocketkv.py (new file, 80 lines)
@ -0,0 +1,80 @@
import json
import os

import pytest
from utils.llm_data import llm_models_root

from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import KvCacheConfig, RocketSparseAttentionConfig


@pytest.mark.parametrize("backend", ["pytorch"])
@pytest.mark.parametrize("model_name",
["llama-3.1-model/Llama-3.1-8B-Instruct"])
@pytest.mark.parametrize("attention_backend", ["VANILLA", "TRTLLM"])
def test_model(backend, model_name, attention_backend):
model_dir = str(llm_models_root() / model_name)
max_batch_size = 16
max_output_tokens = 128
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.7,
enable_block_reuse=False)

sparse_attention_config = RocketSparseAttentionConfig(
window_size=32,
kernel_size=63,
prompt_budget=2048,
)

llm = LLM(
model=model_dir,
backend=backend,
kv_cache_config=kv_cache_config,
attn_backend=attention_backend,
sparse_attention_config=sparse_attention_config,
max_batch_size=max_batch_size,
max_seq_len=8192,
max_num_tokens=8192,
cuda_graph_config=
None,  # sparse attention does not support cuda graph now
)

inputs, references = [], []
current_file = os.path.abspath(__file__)
current_dir = os.path.dirname(os.path.dirname(
os.path.dirname(current_file)))
input_file = f'{current_dir}/multi_gpu/test_star_attention_input.jsonl'
with open(input_file, 'r') as f:
for line in f:
sample = json.loads(line)
inputs.append({
'prompt':
sample['input_context'] + sample['input_query'],
})
references.append(sample['outputs'][0])

with llm:
outputs = llm.generate(
inputs,
use_tqdm=True,
sampling_params=SamplingParams(add_special_tokens=False,
max_tokens=max_output_tokens,
temperature=0.8,
top_p=0.95),
)

count = 0
for ref, ret in zip(references, outputs):
print(f"ret: {ret.outputs[0].text}")
print(f"ref: {ref}")
if ref not in ret.outputs[0].text:
print(f'reference {ref} is not in the output {ret.outputs[0].text}')
else:
count = count + 1
acc = count / len(outputs)

assert acc >= 0.9, 'accuracy test of rocketkv sparse attention failed'


if __name__ == '__main__':
test_model("pytorch", "llama-3.1-model/Llama-3.1-8B-Instruct", "VANILLA")
test_model("pytorch", "llama-3.1-model/Llama-3.1-8B-Instruct", "TRTLLM")
@ -183,6 +183,10 @@ methods:
annotation: Optional[Literal["rpc", "ray"]]
default: null
status: prototype
sparse_attention_config:
annotation: Optional[tensorrt_llm.llmapi.llm_args.SparseAttentionConfig]
default: null
status: prototype
return_annotation: None
generate:
parameters: