[None] [feat] Use triton kernels for RocketKV prediction module (#8682)

Signed-off-by: yuhangh <58161490+heyuhhh@users.noreply.github.com>
heyuhhh 2025-11-14 10:51:09 +08:00 committed by GitHub
parent cc4c980e03
commit f07e9977c6
17 changed files with 3527 additions and 680 deletions


@ -20,7 +20,7 @@ namespace tensorrt_llm
{
namespace kernels
{
template <int THREADS_PER_BLOCK>
template <int THREADS_PER_BLOCK, int MAX_NUM_PAGES>
__global__ void gatherKvPageOffsetsKernel(
int32_t* output_kv_page_offsets, // [num_head_kv, batch_size, 2, max_num_pages_per_seq]
int32_t* output_seq_lengths, // [num_head_kv, batch_size]
@ -32,23 +32,33 @@ __global__ void gatherKvPageOffsetsKernel(
// Each CUDA block processes one sequence from the batch for one head.
int32_t const head_idx = blockIdx.x;
int32_t const batch_idx = blockIdx.y;
int32_t const indices_block_size = sparse_params.sparse_attn_indices_block_size;
if (batch_idx >= batch_size)
{
return;
}
// Shared memory for reduction.
__shared__ typename cub::BlockReduce<Pair, THREADS_PER_BLOCK>::TempStorage temp_storage;
using BlockScan = cub::BlockScan<int32_t, THREADS_PER_BLOCK>;
using BlockReduce = cub::BlockReduce<Pair, THREADS_PER_BLOCK>;
__shared__ typename BlockScan::TempStorage temp_storage_scan;
__shared__ typename BlockReduce::TempStorage temp_storage_reduce;
__shared__ int32_t s_page_mask[MAX_NUM_PAGES];
__shared__ int32_t s_cu_page_mask[MAX_NUM_PAGES];
__shared__ int32_t s_scan_total; // Store total count from scan
// Get the range of sparse indices and the sequence length.
int32_t const start_offset = sparse_params.sparse_attn_offsets[batch_idx];
int32_t const end_offset = sparse_params.sparse_attn_offsets[batch_idx + 1];
int32_t const total_pages = sparse_params.sparse_attn_offsets[batch_size];
int32_t const num_sparse_pages = end_offset - start_offset;
int32_t const sparse_attn_indices_stride = sparse_params.sparse_attn_indices_stride;
int32_t const num_sparse_indices = end_offset - start_offset;
int32_t const original_seq_len = seq_lengths[batch_idx];
int32_t const ori_valid_pages = (original_seq_len + tokens_per_page - 1) / tokens_per_page;
int32_t const page_loops = (ori_valid_pages + MAX_NUM_PAGES - 1) / MAX_NUM_PAGES;
// Get global sparse index.
int32_t const sparse_idx_global = head_idx * total_pages + start_offset;
int32_t const sparse_idx_global = head_idx * sparse_attn_indices_stride + start_offset;
// Get the base memory offset. shape: [batch_size, 2, max_num_pages_per_seq]
size_t const src_base_offset = (size_t) batch_idx * 2 * max_num_pages_per_seq;
@ -58,56 +68,119 @@ __global__ void gatherKvPageOffsetsKernel(
int32_t local_max_page_index = -1;
int32_t local_num_valid_pages = 0;
// Perform the gather operation.
for (int32_t i = threadIdx.x; i < num_sparse_pages; i += blockDim.x)
int32_t src_page_idx_offset = 0;
int32_t dst_page_idx_offset = 0;
for (int32_t loop_idx = 0; loop_idx < page_loops; loop_idx++)
{
// Get the source idx and offset.
int32_t const src_idx = sparse_params.sparse_attn_indices[sparse_idx_global + i];
if (src_idx < 0)
src_page_idx_offset = loop_idx * MAX_NUM_PAGES;
int32_t loop_num_valid_pages = min(MAX_NUM_PAGES, ori_valid_pages - src_page_idx_offset);
for (int32_t i = threadIdx.x; i < MAX_NUM_PAGES; i += blockDim.x)
{
continue;
s_page_mask[i] = 0;
}
__syncthreads();
for (int32_t i = threadIdx.x; i < num_sparse_indices; i += blockDim.x)
{
int32_t const src_idx = sparse_params.sparse_attn_indices[sparse_idx_global + i];
int32_t const src_idx_start = src_idx * indices_block_size;
int32_t const src_idx_end = min(src_idx_start + indices_block_size, original_seq_len);
for (int32_t j = src_idx_start; j < src_idx_end; j++)
{
int32_t const src_page_idx = j / tokens_per_page;
if (src_page_idx >= src_page_idx_offset && src_page_idx < src_page_idx_offset + loop_num_valid_pages)
{
atomicExch(&s_page_mask[src_page_idx - src_page_idx_offset], 1);
}
}
}
__syncthreads();
// Handle case when loop_num_valid_pages > blockDim.x by processing in chunks
int32_t scan_offset = 0;
int32_t const scan_chunks = (loop_num_valid_pages + blockDim.x - 1) / blockDim.x;
for (int32_t chunk_idx = 0; chunk_idx < scan_chunks; chunk_idx++)
{
int32_t const chunk_start = chunk_idx * blockDim.x;
int32_t const chunk_size = min((int32_t) blockDim.x, loop_num_valid_pages - chunk_start);
int32_t thread_data = (threadIdx.x < chunk_size) ? s_page_mask[chunk_start + threadIdx.x] : 0;
int32_t thread_output;
int32_t aggregate;
BlockScan(temp_storage_scan).ExclusiveSum(thread_data, thread_output, aggregate);
__syncthreads();
if (threadIdx.x < chunk_size)
{
s_cu_page_mask[chunk_start + threadIdx.x] = thread_output + scan_offset;
}
__syncthreads();
// Update scan offset for next chunk
scan_offset += aggregate;
}
// Update the local max page index.
local_max_page_index = max(local_max_page_index, src_idx);
local_num_valid_pages++;
if (threadIdx.x == 0)
{
s_scan_total = scan_offset;
}
__syncthreads();
// Get the source and destination offsets.
size_t const src_offset_dim0 = src_base_offset + 0 * max_num_pages_per_seq + src_idx;
size_t const src_offset_dim1 = src_base_offset + 1 * max_num_pages_per_seq + src_idx;
size_t const dst_offset_dim0 = dst_base_offset + 0 * max_num_pages_per_seq + i;
size_t const dst_offset_dim1 = dst_base_offset + 1 * max_num_pages_per_seq + i;
// Perform the gather operation.
for (int32_t i = threadIdx.x; i < loop_num_valid_pages; i += blockDim.x)
{
// Skip if the page is not valid.
if (s_page_mask[i] == 0)
{
continue;
}
// Perform the gather operation: read from the sparse location and write to the dense location.
output_kv_page_offsets[dst_offset_dim0] = kv_page_offsets[src_offset_dim0];
output_kv_page_offsets[dst_offset_dim1] = kv_page_offsets[src_offset_dim1];
int32_t const src_idx = src_page_idx_offset + i;
int32_t const dst_idx = dst_page_idx_offset + s_cu_page_mask[i];
local_max_page_index = max(local_max_page_index, src_idx);
local_num_valid_pages++;
size_t const src_offset_dim0 = src_base_offset + 0 * max_num_pages_per_seq + src_idx;
size_t const src_offset_dim1 = src_base_offset + 1 * max_num_pages_per_seq + src_idx;
size_t const dst_offset_dim0 = dst_base_offset + 0 * max_num_pages_per_seq + dst_idx;
size_t const dst_offset_dim1 = dst_base_offset + 1 * max_num_pages_per_seq + dst_idx;
output_kv_page_offsets[dst_offset_dim0] = kv_page_offsets[src_offset_dim0];
output_kv_page_offsets[dst_offset_dim1] = kv_page_offsets[src_offset_dim1];
}
__syncthreads();
// Update dst offset using the total count from scan
dst_page_idx_offset += s_scan_total;
}
// Reduce the local max page indices and number of valid pages.
Pair local_pair = {local_max_page_index, local_num_valid_pages};
Pair result = cub::BlockReduce<Pair, THREADS_PER_BLOCK>(temp_storage).Reduce(local_pair, PairReduceOp());
Pair result = BlockReduce(temp_storage_reduce).Reduce(local_pair, PairReduceOp());
// Update sequence length for this head and batch.
if (threadIdx.x == 0)
{
int32_t const max_page_index = result.max_val;
int32_t const num_valid_pages = result.sum_val;
int32_t const ori_valid_pages = (original_seq_len + tokens_per_page - 1) / tokens_per_page;
size_t const seq_len_offset = (size_t) head_idx * batch_size + batch_idx;
int32_t seq_len = 0;
if (num_valid_pages > 0)
{
int32_t seq_len = original_seq_len - (ori_valid_pages - num_valid_pages) * tokens_per_page;
int32_t seq_len_remain = original_seq_len % tokens_per_page;
if (max_page_index != ori_valid_pages - 1 && seq_len_remain != 0)
if (max_page_index == ori_valid_pages - 1)
{
seq_len += tokens_per_page - seq_len_remain;
seq_len = (num_valid_pages - 1) * tokens_per_page
+ (original_seq_len - (ori_valid_pages - 1) * tokens_per_page);
}
else
{
seq_len = num_valid_pages * tokens_per_page;
}
output_seq_lengths[seq_len_offset] = seq_len;
}
else
{
output_seq_lengths[seq_len_offset] = 0;
}
output_seq_lengths[seq_len_offset] = seq_len;
}
}
@ -121,11 +194,8 @@ void invokeGatherKvPageOffsets(int32_t* output_kv_page_offsets, int32_t* output_
dim3 grid(num_head_kv, batch_size, 1);
// The block.
dim3 block(256, 1, 1);
// Shared memory size.
size_t smem_size = sizeof(Pair) * 256;
// Launch the kernel.
gatherKvPageOffsetsKernel<256><<<grid, block, smem_size, stream>>>(output_kv_page_offsets, output_seq_lengths,
gatherKvPageOffsetsKernel<256, 512><<<grid, block, 0, stream>>>(output_kv_page_offsets, output_seq_lengths,
kv_page_offsets, seq_lengths, sparse_params, batch_size, tokens_per_page, max_num_pages_per_seq);
}
} // namespace kernels
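For orientation, the reworked kernel no longer copies one KV page per sparse index. Each index now names a block of `sparse_attn_indices_block_size` tokens; the kernel marks every page those tokens touch in shared memory, compacts the marked pages with a block-wide exclusive scan, and recomputes the per-head sequence length from the kept pages. A host-side sketch of that per-(head, batch) computation (the function name and the NumPy port are illustrative, not part of the change):

```python
import numpy as np

def gather_pages_reference(sparse_indices, seq_len, tokens_per_page, indices_block_size=1):
    """Mark the pages touched by the sparse token blocks, compact them densely,
    and derive the effective sequence length, mirroring the kernel's logic."""
    ori_valid_pages = (seq_len + tokens_per_page - 1) // tokens_per_page
    page_mask = np.zeros(ori_valid_pages, dtype=bool)
    for idx in sparse_indices:
        start = idx * indices_block_size
        end = min(start + indices_block_size, seq_len)
        for tok in range(start, end):
            page_mask[tok // tokens_per_page] = True
    kept_pages = np.flatnonzero(page_mask)          # dense destination order
    if kept_pages.size == 0:
        return kept_pages, 0
    if kept_pages[-1] == ori_valid_pages - 1:       # last (possibly partial) page kept
        eff_len = (kept_pages.size - 1) * tokens_per_page \
                  + seq_len - (ori_valid_pages - 1) * tokens_per_page
    else:                                           # only full pages kept
        eff_len = kept_pages.size * tokens_per_page
    return kept_pages, eff_len

# Head 0 / batch 1 of the unit test below: seq_len = 195, tokens_per_page = 64.
print(gather_pages_reference([64, 65, 128, 129, 192, 193], 195, 64))  # (array([1, 2, 3]), 131)
```

The kept page list corresponds to what the kernel writes into `output_kv_page_offsets`, and the derived length to what it stores in `output_seq_lengths`.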


@ -35,6 +35,9 @@ struct SparseAttentionParams
int32_t sparse_mla_topk{0}; // for DSA attention
void* sparse_mla_kv_cache_pool{nullptr}; // for DSA attention
int32_t sparse_attn_indices_block_size{1};
int32_t sparse_attn_indices_stride{0};
std::string toString() const
{
std::stringstream ss;
@ -43,7 +46,9 @@ struct SparseAttentionParams
<< "sparse_kv_offsets: " << this->sparse_kv_offsets << std::endl
<< "sparse_attn_offsets: " << this->sparse_attn_offsets << std::endl
<< "sparse_mla_topk: " << this->sparse_mla_topk << std::endl
<< "sparse_mla_kv_cache_pool: " << this->sparse_mla_kv_cache_pool << std::endl;
<< "sparse_mla_kv_cache_pool: " << this->sparse_mla_kv_cache_pool << std::endl
<< "sparse_attn_indices_block_size: " << this->sparse_attn_indices_block_size << std::endl
<< "sparse_attn_indices_stride: " << this->sparse_attn_indices_stride << std::endl;
return ss.str();
}
};
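The two new fields change how the kernel addresses `sparse_attn_indices`: the tensor is laid out as `[num_head_kv, sparse_attn_indices_stride]`, the entries for sequence `b` of head `h` live at `h * stride + sparse_attn_offsets[b] .. h * stride + sparse_attn_offsets[b+1]`, and each entry names a block of `sparse_attn_indices_block_size` tokens. A small sketch of that addressing with made-up sizes:

```python
import numpy as np

num_head_kv, stride = 2, 14            # stride == sparse_attn_indices.size(-1)
offsets = [0, 8, 14]                   # batch 0 owns 8 entries, batch 1 owns 6
flat = np.arange(num_head_kv * stride) # stand-in for the flattened indices tensor

head, batch = 1, 1
start = head * stride + offsets[batch]
end = head * stride + offsets[batch + 1]
print(flat[start:end])                 # the six entries head 1 reads for batch 1
```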


@ -64,10 +64,11 @@ void initBindings(nb::module_& m)
nb::arg("softmax_stats_tensor") = std::nullopt, nb::arg("spec_decoding_bool_params"),
nb::arg("spec_decoding_tensor_params"), nb::arg("sparse_kv_indices") = std::nullopt,
nb::arg("sparse_kv_offsets") = std::nullopt, nb::arg("sparse_attn_indices") = std::nullopt,
nb::arg("sparse_attn_offsets") = std::nullopt, nb::arg("sparse_mla_topk") = std::nullopt,
nb::arg("cu_q_seqlens") = std::nullopt, nb::arg("cu_kv_seqlens") = std::nullopt,
nb::arg("fmha_scheduler_counter") = std::nullopt, nb::arg("mla_bmm1_scale") = std::nullopt,
nb::arg("mla_bmm2_scale") = std::nullopt, nb::arg("quant_q_buffer") = std::nullopt,
"Multi-head attention operation", nb::call_guard<nb::gil_scoped_release>());
nb::arg("sparse_attn_offsets") = std::nullopt, nb::arg("sparse_attn_indices_block_size"),
nb::arg("sparse_mla_topk") = std::nullopt, nb::arg("cu_q_seqlens") = std::nullopt,
nb::arg("cu_kv_seqlens") = std::nullopt, nb::arg("fmha_scheduler_counter") = std::nullopt,
nb::arg("mla_bmm1_scale") = std::nullopt, nb::arg("mla_bmm2_scale") = std::nullopt,
nb::arg("quant_q_buffer") = std::nullopt, "Multi-head attention operation",
nb::call_guard<nb::gil_scoped_release>());
}
} // namespace tensorrt_llm::nanobind::thop


@ -64,10 +64,11 @@ void initBindings(pybind11::module_& m)
py::arg("softmax_stats_tensor") = std::nullopt, py::arg("spec_decoding_bool_params"),
py::arg("spec_decoding_tensor_params"), py::arg("sparse_kv_indices") = std::nullopt,
py::arg("sparse_kv_offsets") = std::nullopt, py::arg("sparse_attn_indices") = std::nullopt,
py::arg("sparse_attn_offsets") = std::nullopt, py::arg("sparse_mla_topk") = std::nullopt,
py::arg("cu_q_seqlens") = std::nullopt, py::arg("cu_kv_seqlens") = std::nullopt,
py::arg("fmha_scheduler_counter") = std::nullopt, py::arg("mla_bmm1_scale") = std::nullopt,
py::arg("mla_bmm2_scale") = std::nullopt, py::arg("quant_q_buffer") = std::nullopt,
"Multi-head attention operation", py::call_guard<py::gil_scoped_release>());
py::arg("sparse_attn_offsets") = std::nullopt, py::arg("sparse_attn_indices_block_size"),
py::arg("sparse_mla_topk") = std::nullopt, py::arg("cu_q_seqlens") = std::nullopt,
py::arg("cu_kv_seqlens") = std::nullopt, py::arg("fmha_scheduler_counter") = std::nullopt,
py::arg("mla_bmm1_scale") = std::nullopt, py::arg("mla_bmm2_scale") = std::nullopt,
py::arg("quant_q_buffer") = std::nullopt, "Multi-head attention operation",
py::call_guard<py::gil_scoped_release>());
}
} // namespace tensorrt_llm::pybind::thop


@ -86,10 +86,11 @@ public:
c10::ArrayRef<std::optional<torch::Tensor>> spec_decoding_tensor_params,
torch::optional<torch::Tensor> attention_sinks, torch::optional<torch::Tensor> sparse_kv_indices,
torch::optional<torch::Tensor> sparse_kv_offsets, torch::optional<torch::Tensor> sparse_attn_indices,
torch::optional<torch::Tensor> sparse_attn_offsets, int32_t const sparse_mla_topk,
std::optional<torch::Tensor> cu_q_seqlens, std::optional<torch::Tensor> cu_kv_seqlens,
std::optional<torch::Tensor> fmha_scheduler_counter, std::optional<torch::Tensor> mla_bmm1_scale,
std::optional<torch::Tensor> mla_bmm2_scale, std::optional<torch::Tensor> quant_q_buffer) const
torch::optional<torch::Tensor> sparse_attn_offsets, int64_t const sparse_attn_indices_block_size,
int32_t const sparse_mla_topk, std::optional<torch::Tensor> cu_q_seqlens,
std::optional<torch::Tensor> cu_kv_seqlens, std::optional<torch::Tensor> fmha_scheduler_counter,
std::optional<torch::Tensor> mla_bmm1_scale, std::optional<torch::Tensor> mla_bmm2_scale,
std::optional<torch::Tensor> quant_q_buffer) const
= 0;
};
@ -146,10 +147,11 @@ public:
c10::ArrayRef<std::optional<torch::Tensor>> spec_decoding_tensor_params,
torch::optional<torch::Tensor> attention_sinks, torch::optional<torch::Tensor> sparse_kv_indices,
torch::optional<torch::Tensor> sparse_kv_offsets, torch::optional<torch::Tensor> sparse_attn_indices,
torch::optional<torch::Tensor> sparse_attn_offsets, int32_t const sparse_mla_topk,
std::optional<torch::Tensor> cu_q_seqlens, std::optional<torch::Tensor> cu_kv_seqlens,
std::optional<torch::Tensor> fmha_scheduler_counter, std::optional<torch::Tensor> mla_bmm1_scale,
std::optional<torch::Tensor> mla_bmm2_scale, std::optional<torch::Tensor> quant_q_buffer) const override
torch::optional<torch::Tensor> sparse_attn_offsets, int64_t const sparse_attn_indices_block_size,
int32_t const sparse_mla_topk, std::optional<torch::Tensor> cu_q_seqlens,
std::optional<torch::Tensor> cu_kv_seqlens, std::optional<torch::Tensor> fmha_scheduler_counter,
std::optional<torch::Tensor> mla_bmm1_scale, std::optional<torch::Tensor> mla_bmm2_scale,
std::optional<torch::Tensor> quant_q_buffer) const override
{
auto stream = at::cuda::getCurrentCUDAStream(qkv_or_q.get_device());
T* attention_input = static_cast<T*>(qkv_or_q.slice(0, token_offset).data_ptr());
@ -395,6 +397,9 @@ public:
= sparse_attn_indices.has_value() ? sparse_attn_indices.value().data_ptr<int32_t>() : nullptr;
op.mRuntimeSparseAttentionParams.sparse_attn_offsets
= sparse_attn_offsets.has_value() ? sparse_attn_offsets.value().data_ptr<int32_t>() : nullptr;
op.mRuntimeSparseAttentionParams.sparse_attn_indices_block_size = sparse_attn_indices_block_size;
op.mRuntimeSparseAttentionParams.sparse_attn_indices_stride
= sparse_attn_indices.has_value() ? sparse_attn_indices.value().size(-1) : 0;
if (op.isMLAEnabled() && op.mUseSparseAttention)
{
op.mRuntimeSparseAttentionParams.sparse_mla_topk = sparse_mla_topk;
@ -589,10 +594,10 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
std::vector<std::optional<torch::Tensor>> spec_decoding_tensor_params,
std::optional<torch::Tensor> sparse_kv_indices, std::optional<torch::Tensor> sparse_kv_offsets,
std::optional<torch::Tensor> sparse_attn_indices, std::optional<torch::Tensor> sparse_attn_offsets,
std::optional<int64_t> sparse_mla_topk, std::optional<torch::Tensor> cu_q_seqlens,
std::optional<torch::Tensor> cu_kv_seqlens, std::optional<torch::Tensor> fmha_scheduler_counter,
std::optional<torch::Tensor> mla_bmm1_scale, std::optional<torch::Tensor> mla_bmm2_scale,
std::optional<torch::Tensor> quant_q_buffer)
int64_t const sparse_attn_indices_block_size, std::optional<int64_t> sparse_mla_topk,
std::optional<torch::Tensor> cu_q_seqlens, std::optional<torch::Tensor> cu_kv_seqlens,
std::optional<torch::Tensor> fmha_scheduler_counter, std::optional<torch::Tensor> mla_bmm1_scale,
std::optional<torch::Tensor> mla_bmm2_scale, std::optional<torch::Tensor> quant_q_buffer)
{
TLLM_LOG_TRACE("Attention op starts at layer %d", layer_idx);
// Use these tensors to infer if the attention is using KV cache
@ -847,8 +852,8 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
rotary_inv_freq, rotary_cos_sin, latent_cache, q_pe, block_ids_per_seq, mrope_rotary_cos_sin,
mrope_position_deltas, mla_tensor_params, softmax_stats_tensor, spec_decoding_tensor_params,
attention_sinks, sparse_kv_indices, sparse_kv_offsets, sparse_attn_indices, sparse_attn_offsets,
sparse_mla_topk_value, cu_q_seqlens, cu_kv_seqlens, fmha_scheduler_counter, mla_bmm1_scale, mla_bmm2_scale,
quant_q_buffer);
sparse_attn_indices_block_size, sparse_mla_topk_value, cu_q_seqlens, cu_kv_seqlens, fmha_scheduler_counter,
mla_bmm1_scale, mla_bmm2_scale, quant_q_buffer);
}
if ((num_generations > 0) && (attn_input_type != AttentionInputType::ContextOnly))
@ -866,8 +871,8 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
rotary_inv_freq, rotary_cos_sin, latent_cache, q_pe, block_ids_per_seq, mrope_rotary_cos_sin,
mrope_position_deltas, mla_tensor_params, softmax_stats_tensor, spec_decoding_tensor_params,
attention_sinks, sparse_kv_indices, sparse_kv_offsets, sparse_attn_indices, sparse_attn_offsets,
sparse_mla_topk_value, cu_q_seqlens, cu_kv_seqlens, fmha_scheduler_counter, mla_bmm1_scale, mla_bmm2_scale,
quant_q_buffer);
sparse_attn_indices_block_size, sparse_mla_topk_value, cu_q_seqlens, cu_kv_seqlens, fmha_scheduler_counter,
mla_bmm1_scale, mla_bmm2_scale, quant_q_buffer);
}
TLLM_LOG_TRACE("Attention op stops at layer %d", layer_idx);


@ -63,9 +63,9 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
std::vector<std::optional<torch::Tensor>> spec_decoding_tensor_params,
std::optional<torch::Tensor> sparse_kv_indices, std::optional<torch::Tensor> sparse_kv_offsets,
std::optional<torch::Tensor> sparse_attn_indices, std::optional<torch::Tensor> sparse_attn_offsets,
std::optional<int64_t> sparse_mla_topk, std::optional<torch::Tensor> cu_q_seqlens,
std::optional<torch::Tensor> cu_kv_seqlens, std::optional<torch::Tensor> fmha_scheduler_counter,
std::optional<torch::Tensor> mla_bmm1_scale, std::optional<torch::Tensor> mla_bmm2_scale,
std::optional<torch::Tensor> quant_q_buffer);
int64_t const sparse_attn_indices_block_size, std::optional<int64_t> sparse_mla_topk,
std::optional<torch::Tensor> cu_q_seqlens, std::optional<torch::Tensor> cu_kv_seqlens,
std::optional<torch::Tensor> fmha_scheduler_counter, std::optional<torch::Tensor> mla_bmm1_scale,
std::optional<torch::Tensor> mla_bmm2_scale, std::optional<torch::Tensor> quant_q_buffer);
} // namespace torch_ext


@ -36,14 +36,16 @@ TEST_F(sparseAttentionKernelsTest, GatherKvPageOffsetsKernelTest)
constexpr int num_head_kv = 4;
constexpr int max_num_pages_per_seq = 8;
constexpr int tokens_per_page = 64;
constexpr int total_sparse_pages = max_batch_size * max_num_pages_per_seq; // Total sparse pages across all batches
// Batch 0 has 8 sparse tokens, Batch 1 has 6 sparse tokens, total = 14
constexpr int total_sparse_tokens = 14;
// Create input buffers
auto kv_page_offsets
= mBufferManager->gpu(ITensor::makeShape({batch_size, 2, max_num_pages_per_seq}), nvinfer1::DataType::kINT32);
auto seq_lengths = mBufferManager->gpu(ITensor::makeShape({batch_size}), nvinfer1::DataType::kINT32);
// Shape: [num_head_kv, total_sparse_tokens] - flattened across all batches
auto sparse_indices
= mBufferManager->gpu(ITensor::makeShape({total_sparse_pages, num_head_kv}), nvinfer1::DataType::kINT32);
= mBufferManager->gpu(ITensor::makeShape({num_head_kv, total_sparse_tokens}), nvinfer1::DataType::kINT32);
auto sparse_indices_offsets = mBufferManager->gpu(ITensor::makeShape({batch_size + 1}), nvinfer1::DataType::kINT32);
// Create output buffers
@ -57,7 +59,7 @@ TEST_F(sparseAttentionKernelsTest, GatherKvPageOffsetsKernelTest)
ITensor::makeShape({batch_size, 2, max_num_pages_per_seq}), nvinfer1::DataType::kINT32);
auto seq_lengths_host = mBufferManager->pinned(ITensor::makeShape({batch_size}), nvinfer1::DataType::kINT32);
auto sparse_indices_host
= mBufferManager->pinned(ITensor::makeShape({total_sparse_pages, num_head_kv}), nvinfer1::DataType::kINT32);
= mBufferManager->pinned(ITensor::makeShape({num_head_kv, total_sparse_tokens}), nvinfer1::DataType::kINT32);
auto sparse_indices_offsets_host
= mBufferManager->pinned(ITensor::makeShape({batch_size + 1}), nvinfer1::DataType::kINT32);
@ -81,27 +83,43 @@ TEST_F(sparseAttentionKernelsTest, GatherKvPageOffsetsKernelTest)
}
// Initialize sequence lengths
seq_lengths_ptr[0] = 2 * tokens_per_page + 18; // 3 pages for batch 0
seq_lengths_ptr[1] = 3 * tokens_per_page + 3; // 4 pages for batch 1
seq_lengths_ptr[0] = 2 * tokens_per_page + 18; // 3 pages (146 tokens) for batch 0
seq_lengths_ptr[1] = 3 * tokens_per_page + 3; // 4 pages (195 tokens) for batch 1
// Initialize sparse indices with different patterns for different heads
// Shape: {total_sparse_pages, num_head_kv}
// Each head can have its own sparse pattern
int num_sparse_pages = 5;
int sparse_page_indices[5][4] = {{1, 0, 0, 1}, {2, 1, 1, -1}, {-1, 2, -1, -1}, {0, 1, 2, 3}, {3, 3, 3, -1}};
for (int page = 0; page < num_sparse_pages; ++page)
// Initialize sparse indices with token-level indices (indices_block_size = 1)
// Shape: [num_head_kv, total_sparse_tokens]
// All heads have the same number of sparse tokens: 8 for batch 0, 6 for batch 1
// Memory layout: sparse_indices_ptr[head_idx * total_sparse_tokens + token_offset]
std::vector<std::vector<int>> sparse_tokens_per_head
= {// Head 0: Batch 0 [10,20,70,75,90,95,100,105] -> pages [0,0,1,1,1,1,1,1] -> unique [0,1]
// Batch 1 [64,65,128,129,192,193] -> pages [1,1,2,2,3,3] -> unique [1,2,3]
{10, 20, 70, 75, 90, 95, 100, 105, 64, 65, 128, 129, 192, 193},
// Head 1: Batch 0 [5,6,65,66,130,131,135,140] -> pages [0,0,1,1,2,2,2,2] -> unique [0,1,2]
// Batch 1 [70,71,128,129,190,191] -> pages [1,1,2,2,2,2] -> unique [1,2]
{5, 6, 65, 66, 130, 131, 135, 140, 70, 71, 128, 129, 190, 191},
// Head 2: Batch 0 [20,21,80,81,85,86,90,91] -> pages [0,0,1,1,1,1,1,1] -> unique [0,1]
// Batch 1 [64,65,66,67,68,69] -> pages [1,1,1,1,1,1] -> unique [1]
{20, 21, 80, 81, 85, 86, 90, 91, 64, 65, 66, 67, 68, 69},
// Head 3: Batch 0 [70,71,72,73,74,75,76,77] -> pages [1,1,1,1,1,1,1,1] -> unique [1]
// Batch 1 [192,193,194,195,196,197] -> pages [3,3,3,3,3,3] -> unique [3]
{70, 71, 72, 73, 74, 75, 76, 77, 192, 193, 194, 195, 196, 197}};
// Fill sparse_indices_ptr using the defined data
for (int head = 0; head < num_head_kv; ++head)
{
for (int head = 0; head < num_head_kv; ++head)
for (int token_idx = 0; token_idx < total_sparse_tokens; ++token_idx)
{
int idx = head * num_sparse_pages + page;
sparse_indices_ptr[idx] = sparse_page_indices[page][head];
sparse_indices_ptr[head * total_sparse_tokens + token_idx] = sparse_tokens_per_head[head][token_idx];
}
}
// Initialize sparse indices offsets
sparse_indices_offsets_ptr[0] = 0; // Start of batch 0
sparse_indices_offsets_ptr[1] = 3; // Start of batch 1 (3 sparse pages for batch 0)
sparse_indices_offsets_ptr[2] = 5; // End (3 sparse pages for batch 1)
// Initialize sparse indices offsets (these are per-batch offsets into the flattened array)
sparse_indices_offsets_ptr[0] = 0; // Start of batch 0
sparse_indices_offsets_ptr[1] = 8; // Start of batch 1 (batch 0 has 8 sparse tokens)
sparse_indices_offsets_ptr[2] = 14; // End (batch 1 has 6 sparse tokens, total = 14)
// Copy data to GPU
mBufferManager->copy(*kv_page_offsets_host, *kv_page_offsets);
@ -112,6 +130,8 @@ TEST_F(sparseAttentionKernelsTest, GatherKvPageOffsetsKernelTest)
SparseAttentionParams sparse_params;
sparse_params.sparse_attn_indices = bufferCast<int32_t>(*sparse_indices);
sparse_params.sparse_attn_offsets = bufferCast<int32_t>(*sparse_indices_offsets);
sparse_params.sparse_attn_indices_block_size = 1; // Token-level indexing
sparse_params.sparse_attn_indices_stride = total_sparse_tokens;
// Launch the kernel
invokeGatherKvPageOffsets(bufferCast<int32_t>(*output_kv_page_offsets), bufferCast<int32_t>(*output_seq_lengths),
@ -136,21 +156,49 @@ TEST_F(sparseAttentionKernelsTest, GatherKvPageOffsetsKernelTest)
auto output_kv_offsets_ptr = bufferCast<int32_t>(*output_kv_page_offsets_host);
auto output_seq_len_ptr = bufferCast<int>(*output_seq_lengths_host);
// Verify sequence lengths for each head and batch
int expected_seq_lens[4][2] = {
{tokens_per_page + 18, tokens_per_page + 3}, // Head 0: batch 0 has 2 pages, batch 1 has 0 pages
{2 * tokens_per_page + 18, tokens_per_page + 3}, // Head 1: batch 0 has 3 pages, batch 1 has 0 pages
{2 * tokens_per_page, tokens_per_page + 3}, // Head 2: batch 0 has 2 pages, batch 1 has 0 pages
{tokens_per_page, 3} // Head 3: batch 0 has 2 pages, batch 1 has 0 pages
// Define expected results for each head and batch
// Format: {num_pages, {page_indices...}, seq_len}
struct ExpectedResult
{
int num_pages;
std::vector<int> page_indices;
int seq_len;
};
ExpectedResult expected_results[4][2] = {// Head 0
{ // Batch 0: tokens on pages [0,1] -> 2 pages, seq_len = 2 * 64 = 128
{2, {0, 1}, 2 * tokens_per_page},
// Batch 1: tokens on pages [1,2,3] -> 3 pages, max_page=3 (last page)
// seq_len = 195 - (4-3)*64 = 131 (no padding needed, max_page is last page)
{3, {1, 2, 3}, 131}},
// Head 1
{// Batch 0: tokens on pages [0,1,2] -> 3 pages (all), seq_len = 146
{3, {0, 1, 2}, 2 * tokens_per_page + 18},
// Batch 1: tokens on pages [1,2] -> 2 pages, max_page=2 (not last page)
// seq_len = 195 - (4-2)*64 = 67, padding: 67 + (64-3) = 128
{2, {1, 2}, 2 * tokens_per_page}},
// Head 2
{// Batch 0: tokens on pages [0,1] -> 2 pages, seq_len = 128
{2, {0, 1}, 2 * tokens_per_page},
// Batch 1: tokens on page [1] -> 1 page, max_page=1 (not last page)
// seq_len = 195 - (4-1)*64 = 3, padding: 3 + (64-3) = 64
{1, {1}, tokens_per_page}},
// Head 3
{// Batch 0: tokens on page [1] -> 1 page, seq_len = 64
{1, {1}, tokens_per_page},
// Batch 1: tokens on page [3] -> 1 page, max_page=3 (last page)
// seq_len = 195 - (4-1)*64 = 3 (no padding needed, max_page is last page)
{1, {3}, 3}}};
// Verify sequence lengths for each head and batch
for (int h = 0; h < num_head_kv; ++h)
{
for (int b = 0; b < batch_size; ++b)
{
int seq_len_idx = h * batch_size + b;
EXPECT_EQ(output_seq_len_ptr[seq_len_idx], expected_seq_lens[h][b])
<< "Sequence length mismatch at head=" << h << ", batch=" << b;
EXPECT_EQ(output_seq_len_ptr[seq_len_idx], expected_results[h][b].seq_len)
<< "Sequence length mismatch at head=" << h << ", batch=" << b
<< ", expected=" << expected_results[h][b].seq_len << ", got=" << output_seq_len_ptr[seq_len_idx];
}
}
@ -159,20 +207,13 @@ TEST_F(sparseAttentionKernelsTest, GatherKvPageOffsetsKernelTest)
{
for (int b = 0; b < batch_size; ++b)
{
int num_sparse_pages_batch = sparse_indices_offsets_ptr[b + 1] - sparse_indices_offsets_ptr[b];
auto const& expected = expected_results[h][b];
for (int d = 0; d < 2; ++d)
{
for (int p = 0; p < num_sparse_pages_batch; ++p)
for (int p = 0; p < expected.num_pages; ++p)
{
// Calculate expected value (from the sparse index)
int sparse_idx_global = sparse_indices_offsets_ptr[b] + p;
int src_page_idx
= sparse_indices_ptr[h * sparse_indices_offsets_ptr[batch_size] + sparse_idx_global];
if (src_page_idx == -1)
{
continue; // Skip invalid indices
}
int src_page_idx = expected.page_indices[p];
// Calculate output offset
size_t output_offset = h * batch_size * 2 * max_num_pages_per_seq + b * 2 * max_num_pages_per_seq
@ -181,7 +222,9 @@ TEST_F(sparseAttentionKernelsTest, GatherKvPageOffsetsKernelTest)
int expected_value = 1000 + b * 100 + d * 10 + src_page_idx;
EXPECT_EQ(output_kv_offsets_ptr[output_offset], expected_value)
<< "Mismatch at head=" << h << ", batch=" << b << ", dim=" << d << ", page=" << p;
<< "KV page offset mismatch at head=" << h << ", batch=" << b << ", dim=" << d << ", page=" << p
<< ", expected_page_idx=" << src_page_idx << ", expected_value=" << expected_value
<< ", got=" << output_kv_offsets_ptr[output_offset];
}
}
}
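The expected table can be cross-checked by rederiving the unique pages and effective lengths from the per-head token lists above (a standalone check, not part of the test file):

```python
import math

tokens_per_page = 64
seq_lens = [2 * 64 + 18, 3 * 64 + 3]      # 146 and 195 tokens
offsets = [0, 8, 14]                      # per-batch slices of each head's token list
heads = [
    [10, 20, 70, 75, 90, 95, 100, 105, 64, 65, 128, 129, 192, 193],
    [5, 6, 65, 66, 130, 131, 135, 140, 70, 71, 128, 129, 190, 191],
    [20, 21, 80, 81, 85, 86, 90, 91, 64, 65, 66, 67, 68, 69],
    [70, 71, 72, 73, 74, 75, 76, 77, 192, 193, 194, 195, 196, 197],
]
for h, tokens in enumerate(heads):
    for b in range(2):
        pages = sorted({t // tokens_per_page for t in tokens[offsets[b]:offsets[b + 1]]})
        ori_pages = math.ceil(seq_lens[b] / tokens_per_page)
        if pages[-1] == ori_pages - 1:    # last, possibly partial, page is kept
            seq_len = (len(pages) - 1) * tokens_per_page \
                      + seq_lens[b] - (ori_pages - 1) * tokens_per_page
        else:                             # only full pages are kept
            seq_len = len(pages) * tokens_per_page
        print(f"head {h} batch {b}: pages={pages} seq_len={seq_len}")
```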


@ -6,10 +6,11 @@ This example demonstrates how to use sparse attention with TensorRT-LLM.
Supported sparse attention algorithms:
- RocketKV
- DSA
Usage:
```bash
python llm_sparse_attention.py --algo RocketKV --attention_backend TRTLLM --window_size 32 --kernel_size 63 --prompt_budget 2048
python llm_sparse_attention.py --algo ROCKETKV --attention_backend TRTLLM --window_size 32 --kernel_size 63 --prompt_budget 2048
```
"""
import argparse
@ -70,7 +71,7 @@ def parse_arguments():
help="The maximum chunk size for the indexer.")
parser.add_argument("--max_seq_len",
type=int,
default=8192,
default=10240,
help="The maximum sequence length.")
parser.add_argument("--max_batch_size",
type=int,
@ -83,7 +84,7 @@ def parse_arguments():
parser.add_argument(
"--max_num_tokens",
type=int,
default=8192,
default=81920,
help=
"The maximum total tokens (context + generation) across all sequences in a batch."
)
@ -104,7 +105,7 @@ def parse_arguments():
# KV cache
parser.add_argument('--kv_cache_dtype', type=str, default='auto')
parser.add_argument("--kv_cache_fraction", type=float, default=None)
parser.add_argument("--kv_cache_fraction", type=float, default=0.7)
parser.add_argument('--num_samples', type=int, default=10)
# Runtime


@ -356,7 +356,6 @@ def initialize_llm(args: argparse.Namespace) -> Tuple[LLM, AutoTokenizer]:
max_seq_len=args.max_seq_len,
max_num_tokens=args.max_num_tokens,
cuda_graph_config=cuda_graph_config,
torch_compile_config=None,
print_iter_log=args.print_iter_log,
moe_config=MoeConfig(backend=args.moe_backend),
)


@ -28,7 +28,8 @@ from transformers import AutoTokenizer
# Add tensorrt_llm imports
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import KvCacheConfig, RocketSparseAttentionConfig
from tensorrt_llm.llmapi import (CudaGraphConfig, KvCacheConfig,
RocketSparseAttentionConfig)
from tensorrt_llm.logger import logger
# Chat templates mapping
@ -362,6 +363,10 @@ def initialize_llm(args: argparse.Namespace) -> Tuple[LLM, AutoTokenizer]:
sparse_attention_config = None
logger.info("Using standard attention")
cuda_graph_config = CudaGraphConfig(
max_batch_size=args.max_batch_size
) if args.attention_backend == "TRTLLM" else None
# Initialize LLM
llm = LLM(
model=args.model_path,
@ -372,8 +377,7 @@ def initialize_llm(args: argparse.Namespace) -> Tuple[LLM, AutoTokenizer]:
tensor_parallel_size=args.tensor_parallel_size,
max_seq_len=args.max_seq_len,
max_num_tokens=args.max_num_tokens,
cuda_graph_config=None,
torch_compile_config=None,
cuda_graph_config=cuda_graph_config,
)
# Initialize tokenizer
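Net effect of the change above: CUDA graphs are no longer disabled unconditionally but are enabled for the TRTLLM backend, sized by `--max_batch_size`. A condensed sketch of the resulting construction (placeholder values stand in for the parsed CLI arguments):

```python
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import CudaGraphConfig

attention_backend = "TRTLLM"   # stand-in for args.attention_backend
max_batch_size = 8             # stand-in for args.max_batch_size

# CUDA graphs are only turned on for the TRTLLM backend, mirroring the diff.
cuda_graph_config = (CudaGraphConfig(max_batch_size=max_batch_size)
                     if attention_backend == "TRTLLM" else None)

llm = LLM(
    model="/path/to/model",    # placeholder path
    max_seq_len=10240,
    max_num_tokens=81920,
    cuda_graph_config=cuda_graph_config,
)
```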

File diff suppressed because it is too large.


@ -23,7 +23,13 @@ from tensorrt_llm.bindings.internal.batch_manager import \
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.modeling_utils import QuantConfig
from .kernel import triton_index_gather, triton_update_kt_cache
from .kernel import (triton_bmm, triton_flatten_to_batch, triton_index_gather,
triton_rocket_batch_to_flatten,
triton_rocket_paged_kt_cache_bmm, triton_rocket_qk_split,
triton_rocket_reduce_scores,
triton_rocket_update_kt_cache_ctx,
triton_rocket_update_kt_cache_gen, triton_softmax,
triton_topk)
ModelConfig = tensorrt_llm.bindings.ModelConfig
@ -35,8 +41,77 @@ class RocketTrtllmAttentionMetadata(TrtllmAttentionMetadata):
if self.sparse_attention_config is None:
raise ValueError("Sparse attention config is not set")
self.prompt_budget = self.sparse_attention_config.prompt_budget
self.window_size = self.sparse_attention_config.window_size
self.page_size = self.sparse_attention_config.page_size
self.topk = self.sparse_attention_config.topk
capture_graph = torch.cuda.is_current_stream_capturing()
# Cumulative valid sequence lengths for query and key
self.q_cu_seqlens_cuda = self.get_empty(
self.cuda_graph_buffers,
(self.max_num_sequences + 1, ),
dtype=torch.int32,
cache_name="q_cu_seqlens_cuda",
capture_graph=capture_graph,
)
self.q_cu_seqlens = torch.zeros_like(self.q_cu_seqlens_cuda,
device='cpu',
dtype=torch.int32)
self.k_cu_seqlens_cuda = self.get_empty(
self.cuda_graph_buffers,
(self.max_num_sequences + 1, ),
dtype=torch.int32,
cache_name="k_cu_seqlens_cuda",
capture_graph=capture_graph,
)
self.k_cu_seqlens = torch.zeros_like(self.k_cu_seqlens_cuda,
device='cpu',
dtype=torch.int32)
# Context length of RocketKV key for each valid sequence
self.k_context_lens = torch.empty(
self.max_num_sequences,
device='cpu',
dtype=torch.int32,
)
# Cumulative context lengths for each sequence
self.context_cumsum_cuda = self.get_empty(
self.cuda_graph_buffers,
(self.max_num_sequences + 1, ),
dtype=torch.int32,
cache_name="context_cumsum_cuda",
capture_graph=capture_graph,
)
self.context_cumsum = torch.zeros_like(self.context_cumsum_cuda,
device='cpu',
dtype=torch.int32)
# Sparse kv indices offsets for each sequence in context phase
self.sparse_offsets_ctx_cuda = self.get_empty(
self.cuda_graph_buffers,
(self.max_num_sequences + 1, ),
dtype=torch.int32,
cache_name="sparse_offsets_ctx_cuda",
capture_graph=capture_graph,
)
self.sparse_offsets_ctx = torch.zeros_like(self.sparse_offsets_ctx_cuda,
device='cpu',
dtype=torch.int32)
# Valid sequence indices used in sparse kv indices prediction
self.valid_seq_indices_cuda = self.get_empty(
self.cuda_graph_buffers,
(self.max_num_sequences, ),
dtype=torch.int32,
cache_name="valid_seq_indices_cuda",
capture_graph=capture_graph,
)
# KT cache block offsets used in KT cache related kernels
self.kt_cache_block_offsets = self.get_empty(
self.cuda_graph_buffers,
[
@ -54,6 +129,41 @@ class RocketTrtllmAttentionMetadata(TrtllmAttentionMetadata):
pin_memory=True,
)
# Number of KT tokens for each sequence
self.num_kt_tokens = torch.empty(
self.max_num_sequences,
device='cpu',
dtype=torch.int32,
)
# Cumulative KT lengths for each sequence
self.cum_kt_lens_cuda = self.get_empty(
self.cuda_graph_buffers,
(self.max_num_sequences + 1, ),
dtype=torch.int32,
cache_name="cum_kt_lens_cuda",
capture_graph=capture_graph,
)
self.cum_kt_lens = torch.zeros_like(self.cum_kt_lens_cuda,
device='cpu',
dtype=torch.int32)
# Sparse attn indices offsets for each sequence in generation phase
self.sparse_offsets_gen_cuda = self.get_empty(
self.cuda_graph_buffers,
(self.max_num_sequences + 1, ),
dtype=torch.int32,
cache_name="sparse_offsets_gen_cuda",
capture_graph=capture_graph,
)
self.sparse_offsets_gen = torch.zeros_like(self.sparse_offsets_gen_cuda,
device='cpu',
dtype=torch.int32)
# Maximum number of KT tokens
self.max_kt_tokens = (self.max_seq_len + self.page_size -
1) // self.page_size
@property
def kt_tokens_per_block(self) -> Optional[int]:
"""
@ -100,81 +210,85 @@ class RocketTrtllmAttentionMetadata(TrtllmAttentionMetadata):
self.host_kt_cache_block_offsets[:self.num_seqs],
non_blocking=True)
# -------------------------------- Context phase --------------------------------
self.context_cumsum[1:self.num_contexts + 1] = torch.cumsum(
self.prompt_lens_cpu[:self.num_contexts], dim=0)
self.context_cumsum_cuda[:self.num_contexts + 1].copy_(
self.context_cumsum[:self.num_contexts + 1], non_blocking=True)
@torch.compile(dynamic=True)
def convert_token_to_page_sparse_indices(
sparse_attn_indices: torch.Tensor, sparse_attn_offsets: torch.Tensor,
metadata: 'TrtllmAttentionMetadata'
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Convert token-based sparse attention indices to page-based sparse attention indices.
# Skip sparse kv indices prediction for sequences shorter than the prompt budget
valid_mask = self.prompt_lens_cpu[:self.
num_contexts] >= self.prompt_budget
valid_seq_indices = torch.where(valid_mask)[0]
invalid_seq_indices = torch.where(~valid_mask)[0]
valid_batch_size = len(valid_seq_indices)
self.valid_seq_indices_cuda[:valid_batch_size].copy_(valid_seq_indices,
non_blocking=True)
Args:
sparse_attn_indices: Token-based indices with shape [num_tokens, num_kv_heads]
sparse_attn_offsets: Offsets with shape [batch_size+1] indicating token boundaries for each batch
metadata: Attention metadata containing tokens_per_block (page_size)
# Only consider sequences that are long enough for sparse kv indices prediction in context phase
self.k_context_lens[:valid_batch_size] = self.prompt_lens_cpu[
valid_seq_indices] - self.window_size
if valid_batch_size > 0:
# Maximum context length of RocketKV key for valid sequences for padding
self.max_rocket_k_ctx_len = self.k_context_lens[:
valid_batch_size].max(
).item()
else:
self.max_rocket_k_ctx_len = 0
Returns:
Tuple of (page_indices, page_offsets):
- page_indices: Page-based indices with shape [num_pages, num_kv_heads]
- page_offsets: Updated offsets with shape [batch_size+1] indicating page boundaries for each batch
sparse_counts_ctx = torch.zeros(self.num_contexts,
dtype=torch.int32,
device='cpu')
sparse_counts_ctx[valid_seq_indices] = self.prompt_budget
sparse_counts_ctx[invalid_seq_indices] = self.prompt_lens_cpu[
invalid_seq_indices]
Example:
If sparse_attn_indices first dimension is [1, 30, 67] and page_size=32,
the result will be [0, 2] (token 1 -> page 0, token 30 -> page 0, token 67 -> page 2)
"""
page_size = metadata.tokens_per_block
batch_size = sparse_attn_offsets.size(0) - 1
num_kv_heads = sparse_attn_indices.size(1)
self.sparse_offsets_ctx[1:self.num_contexts + 1] = torch.cumsum(
sparse_counts_ctx, dim=0)
self.sparse_offsets_ctx_cuda[:self.num_contexts + 1].copy_(
self.sparse_offsets_ctx[:self.num_contexts + 1], non_blocking=True)
# Convert token indices to page indices
page_indices = sparse_attn_indices // page_size
self.q_cu_seqlens[:valid_batch_size + 1] = torch.arange(
valid_batch_size + 1, device='cpu',
dtype=torch.int32) * self.window_size
self.q_cu_seqlens_cuda[:valid_batch_size + 1].copy_(
self.q_cu_seqlens[:valid_batch_size + 1], non_blocking=True)
# Process each batch and each kv_head separately to remove duplicates
new_page_indices_list = []
new_offsets = torch.zeros_like(sparse_attn_offsets)
self.k_cu_seqlens[1:valid_batch_size + 1] = torch.cumsum(
self.k_context_lens[:valid_batch_size], dim=0)
self.k_cu_seqlens_cuda[:valid_batch_size + 1].copy_(
self.k_cu_seqlens[:valid_batch_size + 1], non_blocking=True)
current_offset = 0
for batch_idx in range(batch_size):
start_idx = sparse_attn_offsets[batch_idx]
end_idx = sparse_attn_offsets[batch_idx + 1]
self.valid_batch_size = valid_batch_size
self.total_sparse_ctx_indices = self.sparse_offsets_ctx[
self.num_contexts].item()
if start_idx >= end_idx:
# Empty batch
new_offsets[batch_idx + 1] = current_offset
continue
# -------------------------------- Generation phase --------------------------------
self.num_kt_tokens[:self.num_generations] = (
self.kv_lens[self.num_contexts:self.num_seqs] + self.page_size -
1) // self.page_size
batch_page_indices = page_indices[
start_idx:end_idx] # [num_tokens_in_batch, num_kv_heads]
self.cum_kt_lens[1:self.num_generations + 1] = torch.cumsum(
self.num_kt_tokens[:self.num_generations], dim=0)
self.cum_kt_lens_cuda[:self.num_generations + 1].copy_(
self.cum_kt_lens[:self.num_generations + 1], non_blocking=True)
# For each kv_head, remove duplicates while preserving order
batch_unique_pages = []
for head_idx in range(num_kv_heads):
head_pages = batch_page_indices[:, head_idx]
unique_pages = torch.unique(head_pages, sorted=False)
batch_unique_pages.append(unique_pages)
self.total_kt_tokens = self.num_generations * self.max_kt_tokens
# Find the maximum length among all heads for this batch
max_len = max(pages.size(0) for pages in batch_unique_pages)
topk_tensor = torch.tensor(self.topk, dtype=torch.int32)
if max_len > 0:
batch_result = torch.full((max_len, num_kv_heads),
fill_value=-1,
dtype=page_indices.dtype,
device=page_indices.device)
# Some sequences may have less than topk KT tokens
# We need to use the minimum of topk and the number of KT tokens
sparse_counts_gen = torch.minimum(
topk_tensor, self.num_kt_tokens[:self.num_generations])
for head_idx in range(num_kv_heads):
unique_pages = batch_unique_pages[head_idx]
batch_result[:unique_pages.size(0), head_idx] = unique_pages
self.sparse_offsets_gen[1:self.num_generations + 1] = torch.cumsum(
sparse_counts_gen[:self.num_generations], dim=0)
self.sparse_offsets_gen_cuda[:self.num_generations + 1].copy_(
self.sparse_offsets_gen[:self.num_generations + 1],
non_blocking=True)
new_page_indices_list.append(batch_result)
current_offset += max_len
new_offsets[batch_idx + 1] = current_offset
new_page_indices = torch.cat(new_page_indices_list, dim=0)
return new_page_indices, new_offsets
self.total_sparse_gen_indices = self.topk * self.num_generations
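The buffers added above all follow the same prepare-time pattern: compute cumulative lengths into a CPU tensor, then copy only the valid prefix into a persistently allocated device buffer with `non_blocking=True`, so device-side shapes stay fixed and CUDA graph capture remains valid. A minimal standalone sketch of that pattern (sizes and values are illustrative; the real buffers come from `get_empty(...)`):

```python
import torch

max_num_sequences, num_contexts = 8, 3
prompt_lens_cpu = torch.tensor([5, 7, 4], dtype=torch.int32)

device = "cuda" if torch.cuda.is_available() else "cpu"
context_cumsum = torch.zeros(max_num_sequences + 1, dtype=torch.int32)              # host staging
context_cumsum_cuda = torch.zeros(max_num_sequences + 1, dtype=torch.int32, device=device)

# Prefix sum of the prompt lengths, written into the host buffer first.
context_cumsum[1:num_contexts + 1] = torch.cumsum(prompt_lens_cpu, dim=0)
# Copy only the valid prefix; the device buffer's shape never changes.
context_cumsum_cuda[:num_contexts + 1].copy_(context_cumsum[:num_contexts + 1],
                                             non_blocking=True)
print(context_cumsum_cuda[:num_contexts + 1].tolist())  # [0, 5, 12, 16]
```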
class RocketTrtllmAttention(TrtllmAttention):
@ -213,85 +327,6 @@ class RocketTrtllmAttention(TrtllmAttention):
self.kernel_size = sparse_attention_config.kernel_size
self.page_size = sparse_attention_config.page_size
def sparse_attn_predict(
self,
q: torch.Tensor,
k: Optional[torch.Tensor],
metadata: TrtllmAttentionMetadata,
**kwargs,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
"""
Predict sparse attention indices.
For RocketKV:
- Generation phase: predict RocketKV sparse attention indices
Returns:
- sparse_attn_indices: [total_selected_indices, num_kv_heads]
- sparse_attn_offsets: [batch_size + 1] with cumulative indices count
"""
if k is None:
q, k, _ = q.split([
self.num_heads * self.head_dim, self.num_kv_heads *
self.head_dim, self.num_kv_heads * self.head_dim
],
dim=-1)
num_contexts = metadata.num_contexts
num_generations = metadata.num_generations
seq_lens = metadata.seq_lens
seq_lens_kv = metadata.seq_lens_kv if metadata.seq_lens_kv is not None else seq_lens
past_seen_tokens = metadata.kv_cache_params.num_cached_tokens_per_seq
sparse_attn_indices = []
sparse_attn_offsets = [0]
q_offset = 0
k_offset = 0
for i in range(num_contexts + num_generations):
seq_len = seq_lens[i].item()
seq_len_kv = seq_lens_kv[i].item()
if seq_len <= 0 or seq_len_kv <= 0:
assert False, "Invalid sequence length"
single_q = q[q_offset:q_offset + seq_len]
single_k = k[k_offset:k_offset + seq_len_kv]
single_q = single_q.view(1, seq_len, self.num_heads,
self.head_dim).transpose(1, 2)
single_k = single_k.view(1, seq_len_kv, self.num_kv_heads,
self.head_dim)
past_seen_token = past_seen_tokens[i]
# Generation phase: RocketKV sparse attention indices
if i >= num_contexts:
_sparse_attn_indices = self._rocketkv_selection(
single_q, single_k, past_seen_token, metadata, i)
if _sparse_attn_indices is not None:
sparse_attn_indices.append(
_sparse_attn_indices.squeeze(0)) # [topk, num_kv_heads]
sparse_attn_offsets.append(sparse_attn_offsets[-1] +
_sparse_attn_indices.size(1))
else:
sparse_attn_offsets.append(sparse_attn_offsets[-1])
q_offset += seq_len
k_offset += seq_len_kv
if len(sparse_attn_indices) == 0:
sparse_attn_indices, sparse_attn_offsets = None, None
else:
sparse_attn_indices = torch.cat(sparse_attn_indices,
dim=0).to(torch.int32)
sparse_attn_offsets = torch.tensor(sparse_attn_offsets,
dtype=torch.int32).to(q.device)
sparse_attn_indices, sparse_attn_offsets = convert_token_to_page_sparse_indices(
sparse_attn_indices, sparse_attn_offsets, metadata)
sparse_attn_indices = sparse_attn_indices.transpose(0,
1).contiguous()
return sparse_attn_indices, sparse_attn_offsets
def sparse_kv_predict(
self,
q: torch.Tensor,
@ -300,229 +335,203 @@ class RocketTrtllmAttention(TrtllmAttention):
**kwargs,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
"""
Predict sparse kv indices.
Predict sparse KV indices using optimized SnapKV algorithm.
For RocketKV:
- Context phase: predict RocketKV sparse kv indices
Returns:
- flattened_indices: [total_selected_indices, num_kv_heads]
- batch_offsets: [batch_size + 1] with cumulative indices count
Uses a single Triton kernel to compute attention scores between observation window
and prefix tokens, then selects the most important prefix tokens directly.
"""
num_ctx_tokens = metadata.num_ctx_tokens
if num_ctx_tokens == 0:
return None, None
if k is None:
q, k, _ = q.split([
self.num_heads * self.head_dim, self.num_kv_heads *
self.head_dim, self.num_kv_heads * self.head_dim
],
dim=-1)
num_contexts = metadata.num_contexts
num_generations = metadata.num_generations
seq_lens = metadata.seq_lens
seq_lens_kv = metadata.seq_lens_kv if metadata.seq_lens_kv is not None else seq_lens
past_seen_tokens = metadata.kv_cache_params.num_cached_tokens_per_seq
sparse_kv_indices = []
sparse_kv_offsets = [0]
q_offset = 0
k_offset = 0
for i in range(num_contexts + num_generations):
seq_len = seq_lens[i].item()
seq_len_kv = seq_lens_kv[i].item()
if seq_len <= 0 or seq_len_kv <= 0:
assert False, "Invalid sequence length"
single_q = q[q_offset:q_offset + seq_len]
single_k = k[k_offset:k_offset + seq_len_kv]
single_q = single_q.view(1, seq_len, self.num_heads,
self.head_dim).transpose(1, 2)
single_k = single_k.view(1, seq_len_kv, self.num_kv_heads,
self.head_dim)
past_seen_token = past_seen_tokens[i]
if i < num_contexts:
# Context phase: SnapKV sparse kv indices
_sparse_kv_indices = self._get_snapkv_indices(
single_q, single_k, past_seen_token, metadata, i)
if _sparse_kv_indices is not None:
sparse_kv_indices.append(
_sparse_kv_indices.squeeze(0)) # [budget, num_kv_heads]
sparse_kv_offsets.append(sparse_kv_offsets[-1] +
_sparse_kv_indices.size(1))
else:
sparse_kv_offsets.append(sparse_kv_offsets[-1])
q_offset += seq_len
k_offset += seq_len_kv
if len(sparse_kv_indices) == 0:
sparse_kv_indices, sparse_kv_offsets = None, None
qkv_input = q[:num_ctx_tokens]
else:
sparse_kv_indices = torch.cat(sparse_kv_indices,
dim=0).to(torch.int32)
sparse_kv_indices = sparse_kv_indices.transpose(0, 1).contiguous()
sparse_kv_offsets = torch.tensor(sparse_kv_offsets,
dtype=torch.int32).to(q.device)
return sparse_kv_indices, sparse_kv_offsets
qkv_input = torch.cat([q, k], dim=1)
def _get_snapkv_indices(self, q: Tensor, k: Tensor, past_seen_token: int,
metadata: RocketTrtllmAttentionMetadata,
sample_idx: int) -> Optional[Tensor]:
"""
Get SnapKV sparse kv indices from the input sequence for context phase.
The shape of output is (1, prompt_budget, num_kv_heads)
"""
bsz = 1
seq_len = k.size(1) # k shape: (1, seq_len, num_kv_heads, head_dim)
if metadata.valid_batch_size > 0:
q_window, k_context = triton_rocket_qk_split(
qkv_input,
metadata.prompt_lens_cuda,
metadata.context_cumsum_cuda,
metadata.valid_seq_indices_cuda,
metadata.k_cu_seqlens_cuda,
self.num_heads,
self.num_kv_heads,
self.head_dim,
self.window_size,
metadata.valid_batch_size,
)
# If the sequence length is less than the prompt budget, do not enable sparse kv cache
if seq_len <= self.prompt_budget:
return None
scores = triton_bmm(q_window,
k_context,
metadata.q_cu_seqlens_cuda,
metadata.k_cu_seqlens_cuda,
metadata.valid_batch_size,
causal=False)
# Use last window_size tokens as observation
# (1, num_heads, window_size, head_dim)
q_obs = q[:, :, -self.window_size:]
# (1, num_kv_heads, seq_len, head_dim)
k_pre = repeat_kv(k.transpose(1, 2),
self.num_heads // self.num_kv_heads)
scores = triton_softmax(scores, metadata.k_cu_seqlens_cuda,
metadata.valid_batch_size)
dist = (torch.arange(0, self.window_size, device=q.device)[:, None] -
torch.arange(0, seq_len, device=q.device)[None, :] + seq_len -
self.window_size)
attention_mask = (dist >= 0)
# scores: [num_heads, window_size, total_k_tokens] -> [num_kv_heads, total_k_tokens]
scores = scores.view(self.num_kv_heads,
self.num_heads // self.num_kv_heads,
self.window_size, -1).sum(dim=(1, 2))
score = torch.matmul(q_obs, k_pre.transpose(-1, -2)) / math.sqrt(
self.head_dim)
# Reshape scores to handle variable length sequences with padding using Triton
# scores: [num_kv_heads, total_k_tokens] -> [valid_batch_size, num_kv_heads, padding_size]
scores = triton_flatten_to_batch(scores, metadata.k_cu_seqlens_cuda,
metadata.valid_batch_size,
metadata.max_rocket_k_ctx_len)
score = torch.masked_fill(
score,
attention_mask.view(1, 1, self.window_size, seq_len) == False,
torch.scalar_tensor(float("-inf"),
device=score.device,
dtype=score.dtype))
scores = torch.nn.functional.max_pool1d(
scores,
kernel_size=self.kernel_size,
padding=self.kernel_size // 2,
stride=1)
score = torch.nn.functional.softmax(score, dim=-1)
selected_prefix_indices = scores.topk(
self.prompt_budget - self.window_size,
dim=-1).indices.sort().values.to(torch.int32)
else:
selected_prefix_indices = torch.empty(
(0, self.num_kv_heads, self.prompt_budget - self.window_size),
device=qkv_input.device,
dtype=torch.int32)
score = torch.masked_fill(
score,
attention_mask.view(1, 1, self.window_size, seq_len) == False,
torch.scalar_tensor(0, device=score.device, dtype=score.dtype))
sparse_kv_offsets = metadata.sparse_offsets_ctx_cuda[:metadata.
num_contexts + 1]
score = score[:, :, -self.window_size:, :-self.window_size].sum(dim=-2)
# Flatten sparse indices
sparse_kv_indices = triton_rocket_batch_to_flatten(
selected_prefix_indices, metadata.prompt_lens_cuda,
metadata.valid_seq_indices_cuda, sparse_kv_offsets,
metadata.num_contexts, metadata.total_sparse_ctx_indices,
self.window_size, self.prompt_budget)
score = score.view(bsz, self.num_kv_heads,
self.num_heads // self.num_kv_heads, -1).sum(dim=2)
score = torch.nn.functional.max_pool1d(score,
kernel_size=self.kernel_size,
padding=self.kernel_size // 2,
stride=1)
# Select top important tokens from prefix
prefix_len = seq_len - self.window_size
selected_prefix_indices = score.topk(self.prompt_budget -
self.window_size,
dim=-1).indices.sort().values
# Combine selected prefix indices with window indices
window_indices = torch.arange(
prefix_len, seq_len,
device=k.device).unsqueeze(0).unsqueeze(0).expand(
bsz, self.num_kv_heads, -1)
selected_indices = torch.cat([selected_prefix_indices, window_indices],
dim=-1).transpose(1, 2)
k = k.reshape(1, -1, self.num_kv_heads, self.head_dim)
k_snap = triton_index_gather(k, selected_indices)
# Update KT cache
kt_cache_tensor = metadata.kv_cache_manager.get_kt_buffers(
self.layer_idx)
k_snap_len = torch.clamp(
metadata.kv_lens_cuda_runtime[sample_idx:sample_idx + 1],
max=self.prompt_budget).int()
triton_update_kt_cache(
k_snap.squeeze(0).contiguous(),
triton_rocket_update_kt_cache_ctx(
qkv_input.contiguous(),
kt_cache_tensor,
metadata.kt_cache_block_offsets[sample_idx:sample_idx + 1],
k_snap_len,
metadata.kt_cache_block_offsets[:metadata.num_contexts],
metadata.context_cumsum_cuda[:metadata.num_contexts + 1],
sparse_kv_indices,
sparse_kv_offsets,
self.num_heads,
self.num_kv_heads,
self.head_dim,
self.page_size,
metadata.kt_tokens_per_block,
metadata.kv_cache_manager.max_kt_blocks_per_seq,
update=False)
)
return selected_indices
# Reduce overhead of post processing
if metadata.valid_batch_size == 0:
return None, None
def _rocketkv_selection(self, q: Tensor, k: Tensor, past_seen_token: int,
metadata: RocketTrtllmAttentionMetadata,
sample_idx: int) -> Tensor:
"""
Implement RocketKV's two-stage selection process for generation phase.
The shape of output is (1, topk, num_kv_heads)
"""
bsz = 1
q_len = q.size(2)
return sparse_kv_indices, sparse_kv_offsets
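For readers comparing against the removed dense path: the context-phase prediction still follows SnapKV, scoring the last `window_size` queries against the prefix keys, pooling, and keeping the `prompt_budget - window_size` highest-scoring prefix positions per KV head (the observation window itself is kept on top of these); the Triton kernels fuse the variable-length batching, softmax, head reduction, and KT cache update. A dense single-sequence reference of that scoring (a sketch only; masking and scaling details of the fused kernels may differ):

```python
import math
import torch

def snapkv_indices_reference(q, k, window_size, kernel_size, prompt_budget):
    """q: [num_heads, seq_len, head_dim], k: [num_kv_heads, seq_len, head_dim].
    Returns sorted prefix indices of shape [num_kv_heads, prompt_budget - window_size]."""
    num_heads, seq_len, head_dim = q.shape
    num_kv_heads = k.shape[0]
    group = num_heads // num_kv_heads

    q_win = q[:, -window_size:]                                   # observation window
    k_ctx = k[:, :seq_len - window_size].repeat_interleave(group, dim=0)
    scores = torch.softmax(q_win @ k_ctx.transpose(-1, -2) / math.sqrt(head_dim), dim=-1)
    # Accumulate over the observation window and over the heads sharing a KV head.
    scores = scores.view(num_kv_heads, group, window_size, -1).sum(dim=(1, 2))
    scores = torch.nn.functional.max_pool1d(scores.unsqueeze(0), kernel_size,
                                            stride=1, padding=kernel_size // 2).squeeze(0)
    return scores.topk(prompt_budget - window_size, dim=-1).indices.sort(dim=-1).values

# Shapes only; values are random.
torch.manual_seed(0)
idx = snapkv_indices_reference(torch.randn(8, 4096, 64), torch.randn(2, 4096, 64),
                               window_size=32, kernel_size=63, prompt_budget=2048)
print(idx.shape)  # torch.Size([2, 2016])
```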
# Helper functions
def _gather(t: Tensor, dim: int, i: Tensor) -> Tensor:
dim += (dim < 0) * t.ndim
return t.gather(
dim, i.expand(*t.shape[:dim], i.shape[dim], *t.shape[dim + 1:]))
@torch.compile(dynamic=True)
def preprocess_for_gen(
self, q: torch.Tensor, k: Optional[torch.Tensor],
metadata: RocketTrtllmAttentionMetadata
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
if k is None:
qkv_input = q[metadata.num_ctx_tokens:]
q_hidden_size = self.num_heads * self.head_dim
k_hidden_size = self.num_kv_heads * self.head_dim
q = qkv_input[:, :q_hidden_size]
k = qkv_input[:, q_hidden_size:q_hidden_size + k_hidden_size]
else:
q = q[metadata.num_ctx_tokens:]
k = k[metadata.num_ctx_tokens:]
@torch.compile(disable=not torch.cuda.is_available())
def _scaled_softmax(x: Tensor, divscale: Tensor | float,
dim: int) -> Tensor:
return torch.softmax(x / divscale, dim=dim)
q = q.view(-1, self.num_kv_heads, self.num_heads // self.num_kv_heads,
self.head_dim)
q_abs = torch.abs(q)
q_mask = torch.zeros_like(q)
i1 = torch.topk(q_abs.mean(dim=2, keepdim=True), self.topr,
dim=-1).indices
q_mask.scatter_(-1, i1.expand_as(q[..., :self.topr]), 1)
q_valid = q * q_mask
dim_pos = torch.where(q_valid.sum(dim=2) > 0, self.head_dim,
0).to(torch.int32)
return q_valid, k, dim_pos
def sparse_attn_predict(
self,
q: torch.Tensor,
k: Optional[torch.Tensor],
metadata: TrtllmAttentionMetadata,
**kwargs,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
if metadata.num_generations == 0:
return None, None
q, k, dim_pos = self.preprocess_for_gen(q, k, metadata)
# Get KT cache for key-token matching
kt_cache_tensor = metadata.kv_cache_manager.get_kt_buffers(
self.layer_idx)
target_seq_len = past_seen_token + 1 # +1 for current token
# Update KT cache
kt_states = triton_update_kt_cache(
k.squeeze(0).contiguous(), kt_cache_tensor,
metadata.kt_cache_block_offsets[sample_idx:sample_idx + 1],
metadata.kv_lens_cuda_runtime[sample_idx:sample_idx + 1],
self.page_size, metadata.kt_tokens_per_block,
metadata.kv_cache_manager.max_kt_blocks_per_seq)
kt_states = kt_states.unsqueeze(0).permute(0, 2, 3, 1)
# Update KT cache with new key values
triton_rocket_update_kt_cache_gen(
k,
kt_cache_tensor,
metadata.kt_cache_block_offsets[metadata.num_contexts:],
metadata.kv_lens_cuda_runtime[metadata.num_contexts:],
metadata.page_size,
metadata.kt_tokens_per_block,
metadata.kv_cache_manager.max_kt_blocks_per_seq,
self.num_kv_heads,
self.head_dim,
)
# Reshape query for multi-head processing
qi = q.view(bsz, self.num_kv_heads, self.num_heads // self.num_kv_heads,
q_len, self.head_dim)
qi_abs = torch.abs(qi)
# Perform BMM with updated cache
scores = triton_rocket_paged_kt_cache_bmm(
q,
kt_cache_tensor,
metadata.kt_cache_block_offsets[metadata.num_contexts:],
dim_pos,
metadata.kv_lens_cuda_runtime[metadata.num_contexts:],
metadata.cum_kt_lens_cuda,
metadata.page_size,
metadata.kt_tokens_per_block,
metadata.kv_cache_manager.max_kt_blocks_per_seq,
metadata.total_kt_tokens,
)
# Top-r selection on query features
i1 = torch.topk(qi_abs.mean(dim=2, keepdim=True), self.topr,
dim=-1).indices
qi_hat = _gather(qi, -1, i1)
scores = triton_softmax(scores, metadata.cum_kt_lens_cuda,
metadata.num_generations)
# Generate signed indices for key-token matching
i1_sign = torch.where(
qi_hat.sum(dim=2, keepdim=True) > 0, i1 + self.head_dim,
i1).transpose(-1, -2)
# Mean over num_heads_per_kv for each batch separately
scores = triton_rocket_reduce_scores(
scores,
metadata.cum_kt_lens_cuda,
metadata.num_generations,
self.num_kv_heads,
self.num_heads // self.num_kv_heads,
)
# Gather key tokens and compute attention scores
kt_hat = _gather(kt_states.unsqueeze(2), -2, i1_sign)
qk_hat = qi_hat @ kt_hat
qk_hat = qk_hat.repeat_interleave(self.page_size,
dim=-1)[:, :, :, :, :target_seq_len]
scale = torch.sqrt(self.head_dim *
torch.abs(qi_hat).sum(dim=-1, keepdim=True) /
qi_abs.sum(dim=-1, keepdim=True))
sparse_attn_offsets = metadata.sparse_offsets_gen_cuda[:metadata.
num_generations +
1]
# (1, num_kv_heads, num_heads, target_seq_len)
s_hat = _scaled_softmax(qk_hat, scale, dim=-1)
selected_indices = triton_topk(scores, metadata.cum_kt_lens_cuda,
sparse_attn_offsets,
metadata.total_sparse_gen_indices,
metadata.topk)
topk = min(self.topk, target_seq_len)
i2 = torch.topk(s_hat.mean(dim=2, keepdim=True), topk, dim=-1).indices
iKV = i2[:, :, 0, 0, :].transpose(1, 2) # (1, topk, num_kv_heads)
return iKV
return selected_indices, sparse_attn_offsets
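The generation-phase `preprocess_for_gen` above compresses each query before the KT-cache matching: per KV-head group it keeps only the `topr` feature dimensions with the largest mean magnitude and zeroes the rest; a companion `dim_pos` tensor, passed to `triton_rocket_paged_kt_cache_bmm`, carries sign information derived from the masked query. A small standalone sketch of the masking step (shapes and `topr` are illustrative):

```python
import torch

num_tokens, num_kv_heads, group, head_dim, topr = 3, 2, 4, 64, 16
q = torch.randn(num_tokens, num_kv_heads, group, head_dim)

q_abs = q.abs()
# One shared set of top-r dimensions per (token, kv head), averaged over the group.
i1 = torch.topk(q_abs.mean(dim=2, keepdim=True), topr, dim=-1).indices
q_mask = torch.zeros_like(q)
q_mask.scatter_(-1, i1.expand(num_tokens, num_kv_heads, group, topr), 1)
q_valid = q * q_mask                   # only top-r feature dimensions remain non-zero

print(q_valid.ne(0).sum(dim=-1)[0])    # every head keeps exactly `topr` dimensions
```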
class RocketVanillaAttentionMetadata(VanillaAttentionMetadata):
@ -920,6 +929,7 @@ class RocketKVCacheManager(KVCacheManager):
max_num_draft_tokens: int = 0,
use_mrope: bool = False,
max_beam_width: int = 1,
num_extra_decoding_steps: int = 0,
):
requests = super().add_dummy_requests(
request_ids=request_ids,
@ -929,6 +939,7 @@ class RocketKVCacheManager(KVCacheManager):
max_num_draft_tokens=max_num_draft_tokens,
use_mrope=use_mrope,
max_beam_width=max_beam_width,
num_extra_decoding_steps=num_extra_decoding_steps,
)
if prepare_resource:
for req in requests:


@ -197,6 +197,7 @@ class TrtllmAttentionWrapper:
sparse_kv_offsets: Optional[torch.Tensor] = None,
sparse_attn_indices: Optional[torch.Tensor] = None,
sparse_attn_offsets: Optional[torch.Tensor] = None,
sparse_attn_indices_block_size: int = 1,
sparse_mla_topk: int = 0,
**kwargs,
):
@ -241,6 +242,7 @@ class TrtllmAttentionWrapper:
sparse_kv_offsets (torch.Tensor): The batch offsets for the sparse KV indices, with shape (num_contexts + 1) on GPU.
sparse_attn_indices (torch.Tensor): The sparse indices for the attention layer, with shape (num_heads_kv, num_sparse_tokens) on GPU.
sparse_attn_offsets (torch.Tensor): The batch offsets for the sparse attention indices, with shape (num_generations + 1) on GPU.
sparse_attn_indices_block_size (int): The granularity of the sparse attention indices, used by block sparse attention.
sparse_mla_topk (int): The topk for the sparse MLA, used by DSA attention.
"""
self.layer_idx = layer_idx
@ -283,6 +285,7 @@ class TrtllmAttentionWrapper:
self.sparse_kv_offsets = sparse_kv_offsets
self.sparse_attn_indices = sparse_attn_indices
self.sparse_attn_offsets = sparse_attn_offsets
self.sparse_attn_indices_block_size = sparse_attn_indices_block_size
self.sparse_mla_topk = sparse_mla_topk
if max_sequence_length > self.rope_params.max_positions:
self.rope_params.max_positions = max_sequence_length
@ -525,6 +528,7 @@ class TrtllmAttentionWrapper:
self.sparse_kv_offsets,
self.sparse_attn_indices,
self.sparse_attn_offsets,
self.sparse_attn_indices_block_size,
self.sparse_mla_topk,
cu_q_seqlens,
cu_kv_seqlens,
@ -1308,11 +1312,14 @@ class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]):
)
sparse_kv_indices, sparse_kv_offsets, sparse_attn_indices, sparse_attn_offsets = None, None, None, None
sparse_attn_indices_block_size = 1
if self.sparse_attention_config is not None:
sparse_kv_indices, sparse_kv_offsets = self.sparse_kv_predict(
q, k, metadata, **kwargs)
sparse_attn_indices, sparse_attn_offsets = self.sparse_attn_predict(
q, k, metadata, **kwargs)
sparse_attn_indices_block_size = self.sparse_attention_config.get_indices_block_size(
)
self.wrapper.plan(
layer_idx=self.get_local_layer_idx(metadata),
@ -1366,8 +1373,10 @@ class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]):
sparse_kv_offsets=sparse_kv_offsets,
sparse_attn_indices=sparse_attn_indices,
sparse_attn_offsets=sparse_attn_offsets,
sparse_attn_indices_block_size=sparse_attn_indices_block_size,
sparse_mla_topk=metadata.sparse_mla_topk if hasattr(
metadata, 'sparse_mla_topk') else 0)
metadata, 'sparse_mla_topk') else 0,
)
out_dtype = None
if out_scale is not None:
if use_nvfp4_output:
@ -1589,18 +1598,6 @@ class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]):
self.mla_params.v_head_dim,
)
def sparse_attn_predict(
self,
q: torch.Tensor,
k: Optional[torch.Tensor],
metadata: TrtllmAttentionMetadata,
**kwargs,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
"""
        Predict sparse attention indices. Implemented by derived classes.
"""
raise NotImplementedError
def sparse_kv_predict(
self,
q: torch.Tensor,
@ -1613,6 +1610,18 @@ class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]):
"""
raise NotImplementedError
def sparse_attn_predict(
self,
q: torch.Tensor,
k: Optional[torch.Tensor],
metadata: TrtllmAttentionMetadata,
**kwargs,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
"""
        Predict sparse attention indices. Implemented by derived classes.
"""
raise NotImplementedError
def mla_rope_generation(
self,
fused_q: torch.Tensor,

View File

@ -215,6 +215,9 @@ class BaseSparseAttentionConfig(StrictBaseModel):
"""
return True
def get_indices_block_size(self) -> int:
return 1
class RocketSparseAttentionConfig(BaseSparseAttentionConfig):
"""
@ -238,6 +241,9 @@ class RocketSparseAttentionConfig(BaseSparseAttentionConfig):
def supports_backend(self, backend: str) -> bool:
return backend == "pytorch"
def get_indices_block_size(self) -> int:
return self.page_size
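# --- Illustrative sketch (not part of the diff) -------------------------------
# get_indices_block_size() tells the attention backend how coarse the predicted
# sparse_attn_indices are: RocketKV reports one index per page of `page_size`
# tokens, so index i stands for tokens [i * page_size, (i + 1) * page_size).
# A minimal, hypothetical expansion back to token granularity:
import torch

def expand_block_indices(block_indices: torch.Tensor, block_size: int,
                         seq_len: int) -> torch.Tensor:
    # block_indices: [num_selected] page ids -> flat token ids, clipped to seq_len.
    token_ids = (block_indices.unsqueeze(-1) * block_size
                 + torch.arange(block_size, device=block_indices.device))
    token_ids = token_ids.flatten()
    return token_ids[token_ids < seq_len]
# -------------------------------------------------------------------------------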
class DeepSeekSparseAttentionConfig(BaseSparseAttentionConfig):
"""

View File

@ -2,10 +2,18 @@ import json
import os
import pytest
import torch
from utils.llm_data import llm_models_root
import tensorrt_llm
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import KvCacheConfig, RocketSparseAttentionConfig
from tensorrt_llm._torch.attention_backend.sparse.rocket import (
RocketKVCacheManager, RocketTrtllmAttention, RocketTrtllmAttentionMetadata,
RocketVanillaAttention, RocketVanillaAttentionMetadata)
from tensorrt_llm._torch.metadata import KVCacheParams
from tensorrt_llm.llmapi import (CudaGraphConfig, KvCacheConfig,
RocketSparseAttentionConfig)
from tensorrt_llm.mapping import Mapping
@pytest.mark.parametrize("backend", ["pytorch"])
@ -25,6 +33,11 @@ def test_model(backend, model_name, attention_backend):
prompt_budget=2048,
)
cuda_graph_config = CudaGraphConfig(
batch_sizes=[1, 2, 4, 8, 16],
enable_padding=True,
)
llm = LLM(
model=model_dir,
backend=backend,
@ -32,10 +45,10 @@ def test_model(backend, model_name, attention_backend):
attn_backend=attention_backend,
sparse_attention_config=sparse_attention_config,
max_batch_size=max_batch_size,
max_seq_len=8192,
max_num_tokens=8192,
cuda_graph_config=
None, # sparse attention does not support cuda graph now
max_seq_len=20480,
max_num_tokens=81920,
cuda_graph_config=None
if attention_backend == "VANILLA" else cuda_graph_config,
)
inputs, references = [], []
@ -75,6 +88,559 @@ def test_model(backend, model_name, attention_backend):
assert acc >= 0.9, 'accuracy test of rocketkv sparse attention failed'
def create_rocket_kv_cache_manager(num_layers, num_kv_heads, head_dim,
tokens_per_block, max_seq_len,
max_batch_size, dtype, sparse_attn_config):
mapping = Mapping(world_size=1, tp_size=1, rank=0)
num_blocks = 100
kv_cache_config = KvCacheConfig(max_tokens=num_blocks * tokens_per_block,
enable_block_reuse=False)
kv_cache_manager = RocketKVCacheManager(
kv_cache_config,
tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF,
num_layers=num_layers,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
tokens_per_block=tokens_per_block,
max_seq_len=max_seq_len,
max_batch_size=max_batch_size,
mapping=mapping,
dtype=dtype,
sparse_attn_config=sparse_attn_config,
)
return kv_cache_manager
def create_test_metadata(seq_lens, num_contexts, past_seen_tokens, request_ids,
kv_cache_manager, sparse_attn_config, metadata_cls):
prompt_lens = []
for i, (seq_len, past_token) in enumerate(zip(seq_lens, past_seen_tokens)):
if i < num_contexts:
prompt_lens.append(seq_len)
else:
prompt_lens.append(past_token + seq_len)
metadata = metadata_cls(
seq_lens=torch.tensor(seq_lens, dtype=torch.int),
num_contexts=num_contexts,
kv_cache_params=KVCacheParams(
use_cache=True, num_cached_tokens_per_seq=past_seen_tokens),
max_num_requests=len(seq_lens),
max_num_sequences=len(seq_lens),
max_num_tokens=8192,
kv_cache_manager=kv_cache_manager,
request_ids=request_ids,
prompt_lens=prompt_lens,
sparse_attention_config=sparse_attn_config,
)
metadata.prepare()
return metadata
@pytest.mark.parametrize(
"batch_size,num_contexts",
[
        (1, 1),  # bs=1, context only (1 context)
        (4, 4),  # bs=4, context only (4 contexts)
(6, 3), # bs=6, mixed (3 contexts + 3 generations)
])
def test_sparse_kv_predict(batch_size, num_contexts):
"""
Test sparse_kv_predict against vanilla _get_snapkv_indices.
This test verifies that the batched implementation produces the same results
as the single-request implementation for SnapKV sparse attention.
"""
# Test configuration
num_heads = 32
num_kv_heads = 8
head_dim = 128
device = torch.device('cuda')
dtype = torch.bfloat16
sparse_attn_config = RocketSparseAttentionConfig(
window_size=32,
kernel_size=3,
prompt_budget=256,
page_size=3,
)
# Create sequence lengths - mix short and long sequences in context phase
seq_lens = []
past_seen_tokens = []
for i in range(batch_size):
if i < num_contexts:
# Context phase: mix sequences shorter and longer than prompt_budget
if i % 2 == 1 and batch_size > 1:
# Short sequence: seq_len < prompt_budget
seq_lens.append(
torch.randint(sparse_attn_config.prompt_budget // 2,
sparse_attn_config.prompt_budget - 10,
(1, )).item())
else:
# Long sequence: seq_len > prompt_budget
seq_lens.append(
torch.randint(sparse_attn_config.prompt_budget,
sparse_attn_config.prompt_budget + 200,
(1, )).item())
past_seen_tokens.append(0)
else:
# Generation phase: single token
seq_lens.append(1)
past_seen_tokens.append(torch.randint(100, 200, (1, )).item())
request_ids = list(range(batch_size))
num_layers = 1
tokens_per_block = 64
max_seq_len = 4096
if dtype == torch.float16:
kv_cache_dtype = tensorrt_llm.bindings.DataType.HALF
elif dtype == torch.bfloat16:
kv_cache_dtype = tensorrt_llm.bindings.DataType.BF16
else:
raise ValueError("Invalid dtype")
vanilla_tokens_per_block = max_seq_len # Each sequence in one block
trtllm_kv_cache_manager = create_rocket_kv_cache_manager(
num_layers, num_kv_heads, head_dim, tokens_per_block, max_seq_len,
batch_size, kv_cache_dtype, sparse_attn_config)
vanilla_kv_cache_manager = create_rocket_kv_cache_manager(
num_layers, num_kv_heads, head_dim, vanilla_tokens_per_block,
max_seq_len, batch_size, kv_cache_dtype, sparse_attn_config)
# Add dummy requests to both cache managers
token_nums = [
seq_len + past_token
for seq_len, past_token in zip(seq_lens, past_seen_tokens)
]
trtllm_kv_cache_manager.add_dummy_requests(request_ids, token_nums)
vanilla_kv_cache_manager.add_dummy_requests(request_ids, token_nums)
trtllm_attn = RocketTrtllmAttention(
layer_idx=0,
num_heads=num_heads,
head_dim=head_dim,
num_kv_heads=num_kv_heads,
sparse_attention_config=sparse_attn_config,
)
vanilla_attn = RocketVanillaAttention(
layer_idx=0,
num_heads=num_heads,
head_dim=head_dim,
num_kv_heads=num_kv_heads,
sparse_attention_config=sparse_attn_config,
)
trtllm_metadata = create_test_metadata(seq_lens, num_contexts,
past_seen_tokens, request_ids,
trtllm_kv_cache_manager,
sparse_attn_config,
RocketTrtllmAttentionMetadata)
vanilla_metadata = create_test_metadata(seq_lens, num_contexts,
past_seen_tokens, request_ids,
vanilla_kv_cache_manager,
sparse_attn_config,
RocketVanillaAttentionMetadata)
total_tokens = sum(seq_lens)
qkv = torch.randn(total_tokens, (num_heads + 2 * num_kv_heads) * head_dim,
dtype=dtype,
device=device)
trtllm_sparse_kv_indices, trtllm_sparse_kv_offsets = trtllm_attn.sparse_kv_predict(
qkv, None, trtllm_metadata)
vanilla_sparse_kv_indices_list = []
offset = 0
for i in range(num_contexts):
seq_len = seq_lens[i]
single_qkv = qkv[offset:offset + seq_len]
q, k, _ = single_qkv.split([
num_heads * head_dim, num_kv_heads * head_dim,
num_kv_heads * head_dim
],
dim=-1)
q = q.view(1, seq_len, num_heads, head_dim).transpose(1, 2)
k = k.view(1, seq_len, num_kv_heads, head_dim)
if seq_len <= sparse_attn_config.prompt_budget:
# Short sequences: vanilla returns None, but trtllm returns [0, 1, ..., seq_len-1]
# Generate expected indices for comparison
short_indices = torch.arange(seq_len,
device=device,
dtype=torch.int32).unsqueeze(0).expand(
num_kv_heads, -1)
vanilla_sparse_kv_indices_list.append(short_indices)
else:
vanilla_indices = vanilla_attn._get_snapkv_indices(q, k, i)
if vanilla_indices is not None:
vanilla_indices = vanilla_indices.squeeze(0).transpose(
0, 1).contiguous()
vanilla_sparse_kv_indices_list.append(vanilla_indices)
offset += seq_len
if len(vanilla_sparse_kv_indices_list) > 0:
vanilla_sparse_kv_indices = torch.cat(vanilla_sparse_kv_indices_list,
dim=-1).contiguous()
else:
vanilla_sparse_kv_indices = None
# Compare results
if trtllm_sparse_kv_indices is not None:
assert vanilla_sparse_kv_indices is not None, "Vanilla should also produce indices"
assert trtllm_sparse_kv_indices.shape == vanilla_sparse_kv_indices.shape, \
f"Shape mismatch: {trtllm_sparse_kv_indices.shape} vs {vanilla_sparse_kv_indices.shape}"
# Check indices overlap per batch and per head
num_kv_heads = trtllm_sparse_kv_indices.shape[0]
# trtllm_sparse_kv_offsets tells where each batch's indices start/end
trtllm_offsets = trtllm_sparse_kv_offsets.cpu().tolist()
overlap_ratios = []
batch_overlap_details = []
for batch_idx in range(num_contexts):
start_idx = trtllm_offsets[batch_idx]
end_idx = trtllm_offsets[batch_idx + 1]
batch_overlaps = []
for head_idx in range(num_kv_heads):
trtllm_batch = trtllm_sparse_kv_indices[
head_idx, start_idx:end_idx].cpu().tolist()
vanilla_batch = vanilla_sparse_kv_indices[
head_idx, start_idx:end_idx].cpu().tolist()
trtllm_set = set(trtllm_batch)
vanilla_set = set(vanilla_batch)
# Calculate overlap
overlap = len(vanilla_set & trtllm_set)
overlap_ratio = overlap / len(vanilla_set) if len(
vanilla_set) > 0 else 1.0
batch_overlaps.append(overlap_ratio)
overlap_ratios.append(overlap_ratio)
avg_batch_overlap = sum(batch_overlaps) / len(batch_overlaps)
batch_overlap_details.append(
f"Batch {batch_idx}: {avg_batch_overlap:.4f}")
avg_overlap_ratio = sum(overlap_ratios) / len(overlap_ratios)
print(f"Average overlap ratio: {avg_overlap_ratio:.4f}")
print(f"Per-batch average: {batch_overlap_details}")
assert avg_overlap_ratio >= 0.98, \
f"Indices overlap ratio {avg_overlap_ratio:.4f} is too low (< 0.98)"
else:
assert vanilla_sparse_kv_indices is None, "Both should return None when no sparse attention is needed"
@pytest.mark.parametrize(
"batch_size,num_contexts",
[
(1, 0), # bs=1, generation only (1 generation)
(2, 0), # bs=2, generation only (2 generations)
(3, 0), # bs=3, generation only (3 generations)
(5, 3), # bs=5, mixed (3 contexts + 2 generations)
(6, 2), # bs=6, mixed (2 ctx + 4 gen)
])
def test_sparse_attn_predict(batch_size, num_contexts):
"""Test sparse_attn_predict against vanilla _rocketkv_selection."""
num_generations = batch_size - num_contexts
# Test configuration
num_heads = 32
num_kv_heads = 8
head_dim = 128
device = torch.device('cuda')
dtype = torch.bfloat16
sparse_attn_config_vanilla = RocketSparseAttentionConfig(
window_size=32,
kernel_size=3,
prompt_budget=256,
page_size=3,
topk=128,
topr=96,
)
sparse_attn_config_trtllm = RocketSparseAttentionConfig(
window_size=32,
kernel_size=3,
prompt_budget=256,
page_size=3,
topk=43,
topr=96,
)
# Create sequence lengths
seq_lens = []
past_seen_tokens = []
for i in range(batch_size):
if i < num_contexts:
# Context phase: longer sequences
seq_lens.append(torch.randint(300, 400, (1, )).item())
past_seen_tokens.append(0)
else:
# Generation phase: single token
seq_lens.append(1)
# 128 is the minimum number of tokens for shape alignment
past_seen_tokens.append(torch.randint(128, 300, (1, )).item())
request_ids = list(range(batch_size))
num_layers = 1
tokens_per_block = 64
max_seq_len = 4096
if dtype == torch.float16:
kv_cache_dtype = tensorrt_llm.bindings.DataType.HALF
elif dtype == torch.bfloat16:
kv_cache_dtype = tensorrt_llm.bindings.DataType.BF16
else:
raise ValueError("Invalid dtype")
vanilla_tokens_per_block = max_seq_len # Each sequence in one block
trtllm_kv_cache_manager = create_rocket_kv_cache_manager(
num_layers, num_kv_heads, head_dim, tokens_per_block, max_seq_len,
batch_size, kv_cache_dtype, sparse_attn_config_trtllm)
vanilla_kv_cache_manager = create_rocket_kv_cache_manager(
num_layers, num_kv_heads, head_dim, vanilla_tokens_per_block,
max_seq_len, batch_size, kv_cache_dtype, sparse_attn_config_vanilla)
# Add dummy requests to both cache managers
token_nums = [
seq_len + past_token
for seq_len, past_token in zip(seq_lens, past_seen_tokens)
]
trtllm_kv_cache_manager.add_dummy_requests(request_ids, token_nums)
vanilla_kv_cache_manager.add_dummy_requests(request_ids, token_nums)
trtllm_attn = RocketTrtllmAttention(
layer_idx=0,
num_heads=num_heads,
head_dim=head_dim,
num_kv_heads=num_kv_heads,
sparse_attention_config=sparse_attn_config_trtllm,
)
vanilla_attn = RocketVanillaAttention(
layer_idx=0,
num_heads=num_heads,
head_dim=head_dim,
num_kv_heads=num_kv_heads,
sparse_attention_config=sparse_attn_config_vanilla,
)
trtllm_metadata = create_test_metadata(seq_lens, num_contexts,
past_seen_tokens, request_ids,
trtllm_kv_cache_manager,
sparse_attn_config_trtllm,
RocketTrtllmAttentionMetadata)
vanilla_metadata = create_test_metadata(seq_lens, num_contexts,
past_seen_tokens, request_ids,
vanilla_kv_cache_manager,
sparse_attn_config_vanilla,
RocketVanillaAttentionMetadata)
total_tokens = sum(seq_lens)
qkv = torch.randn(total_tokens, (num_heads + 2 * num_kv_heads) * head_dim,
dtype=dtype,
device=device)
for layer_idx in range(num_layers):
trtllm_kt_buf = trtllm_kv_cache_manager.get_kt_buffers(layer_idx)
vanilla_kt_buf = vanilla_kv_cache_manager.get_kt_buffers(layer_idx)
torch.nn.init.normal_(trtllm_kt_buf)
# Map trtllm data to vanilla based on block offsets
# TRTLLM: (num_blocks, kt_tokens_per_block, num_kv_heads, 2*head_dim)
# VANILLA: (num_blocks, kt_tokens_per_block_vanilla, num_kv_heads, 2*head_dim)
trtllm_kt_tokens_per_block = trtllm_kv_cache_manager.kt_tokens_per_block
trtllm_block_offsets = trtllm_metadata.kt_cache_block_offsets
vanilla_block_offsets = vanilla_metadata.kt_cache_block_offsets
for req_idx in range(num_contexts, batch_size):
# Get the number of KT tokens for this request
past_token = past_seen_tokens[req_idx]
num_kt_tokens = (past_token + 1 +
sparse_attn_config_trtllm.page_size -
1) // sparse_attn_config_trtllm.page_size
# Get block offsets for this request
trtllm_blocks = trtllm_block_offsets[req_idx]
vanilla_blocks = vanilla_block_offsets[req_idx]
# Copy data from trtllm blocks to vanilla blocks
kt_token_idx = 0
vanilla_block_idx = 0
# For trtllm: iterate through blocks and copy KT tokens
for trtllm_block_local_idx in range(len(trtllm_blocks)):
if kt_token_idx >= num_kt_tokens:
break
trtllm_block = trtllm_blocks[trtllm_block_local_idx]
if trtllm_block < 0:
break
# How many KT tokens in this trtllm block
kt_tokens_in_this_block = min(trtllm_kt_tokens_per_block,
num_kt_tokens - kt_token_idx)
# Copy to vanilla buffer
vanilla_block = vanilla_blocks[vanilla_block_idx]
if vanilla_block >= 0:
vanilla_kt_buf[vanilla_block, kt_token_idx:kt_token_idx +
kt_tokens_in_this_block].copy_(trtllm_kt_buf[
trtllm_block, :kt_tokens_in_this_block])
kt_token_idx += kt_tokens_in_this_block
trtllm_sparse_attn_indices, trtllm_sparse_attn_offsets = trtllm_attn.sparse_attn_predict(
qkv, None, trtllm_metadata)
vanilla_sparse_attn_indices_list = []
offset = sum(seq_lens[:num_contexts]) # Skip context tokens
for i in range(num_contexts, batch_size):
seq_len = seq_lens[i]
single_qkv = qkv[offset:offset + seq_len]
q, k, _ = single_qkv.split([
num_heads * head_dim, num_kv_heads * head_dim,
num_kv_heads * head_dim
],
dim=-1)
q = q.view(1, seq_len, num_heads, head_dim).transpose(1, 2)
k = k.view(1, seq_len, num_kv_heads, head_dim)
past_seen_token = past_seen_tokens[i]
vanilla_indices = vanilla_attn._rocketkv_selection(
q, k, vanilla_metadata, past_seen_token, i)
vanilla_sparse_attn_indices_list.append(vanilla_indices.squeeze(0))
offset += seq_len
if trtllm_sparse_attn_indices is not None:
        assert len(vanilla_sparse_attn_indices_list) > 0, "Vanilla should also produce indices"
vanilla_sparse_attn_indices = torch.cat(
vanilla_sparse_attn_indices_list, dim=0).transpose(0,
1).contiguous()
# Apply interleave operation to trtllm indices
# For each head, multiply indices by page_size and expand to include all tokens in each page
page_size = sparse_attn_config_trtllm.page_size
num_kv_heads, total_indices = trtllm_sparse_attn_indices.shape
interleaved_indices_list = []
for head_idx in range(num_kv_heads):
head_indices = trtllm_sparse_attn_indices[
head_idx] # Shape: [total_indices]
page_starts = head_indices * page_size # Shape: [total_indices]
expanded_indices = []
for page_start in page_starts:
page_indices = torch.arange(page_start,
page_start + page_size,
device=page_starts.device)
expanded_indices.append(page_indices)
head_interleaved = torch.cat(expanded_indices, dim=0)
# Slice to match vanilla shape
target_length = vanilla_sparse_attn_indices.shape[1]
head_interleaved = head_interleaved[:target_length]
interleaved_indices_list.append(head_interleaved)
# Stack all heads
trtllm_sparse_attn_indices = torch.stack(interleaved_indices_list,
dim=0)
assert trtllm_sparse_attn_indices.shape == vanilla_sparse_attn_indices.shape, \
f"Shape mismatch: {trtllm_sparse_attn_indices.shape} vs {vanilla_sparse_attn_indices.shape}"
trtllm_sparse_attn_indices = trtllm_sparse_attn_indices.sort(
dim=-1).values
vanilla_sparse_attn_indices = vanilla_sparse_attn_indices.sort(
dim=-1).values
# Check indices overlap per batch and per head
num_kv_heads = trtllm_sparse_attn_indices.shape[0]
trtllm_offsets = trtllm_sparse_attn_offsets.cpu().tolist()
overlap_ratios = []
batch_overlap_details = []
num_generations = batch_size - num_contexts
for batch_idx in range(num_generations):
start_idx = trtllm_offsets[batch_idx]
end_idx = trtllm_offsets[batch_idx + 1]
batch_overlaps = []
for head_idx in range(num_kv_heads):
trtllm_batch = trtllm_sparse_attn_indices[
head_idx, start_idx:end_idx].cpu().tolist()
vanilla_batch = vanilla_sparse_attn_indices[
head_idx, start_idx:end_idx].cpu().tolist()
trtllm_set = set(trtllm_batch)
vanilla_set = set(vanilla_batch)
# Calculate overlap
overlap = len(vanilla_set & trtllm_set)
overlap_ratio = overlap / len(vanilla_set) if len(
vanilla_set) > 0 else 1.0
batch_overlaps.append(overlap_ratio)
overlap_ratios.append(overlap_ratio)
avg_batch_overlap = sum(batch_overlaps) / len(batch_overlaps)
batch_overlap_details.append(
f"Batch {batch_idx}: {avg_batch_overlap:.4f}")
avg_overlap_ratio = sum(overlap_ratios) / len(overlap_ratios)
print(f"Average overlap ratio: {avg_overlap_ratio:.4f}")
print(f"Per-batch average: {batch_overlap_details}")
threshold = 0.94
assert avg_overlap_ratio >= threshold, \
f"Indices overlap ratio {avg_overlap_ratio:.4f} is too low (< {threshold})"
else:
        assert len(vanilla_sparse_attn_indices_list) == 0, \
            "Both should return None when no sparse attention is needed"
if __name__ == '__main__':
# RocketKV e2e tests
print("=== Testing RocketKV E2E tests ===")
test_model("pytorch", "llama-3.1-model/Llama-3.1-8B-Instruct", "VANILLA")
test_model("pytorch", "llama-3.1-model/Llama-3.1-8B-Instruct", "TRTLLM")
# Unit tests for sparse_kv_predict
print("\n=== Testing sparse_kv_predict ===")
test_sparse_kv_predict(1, 1) # bs=1, context only
test_sparse_kv_predict(2, 2) # bs=2, context only
test_sparse_kv_predict(6, 3) # bs=6, mixed (3 ctx + 3 gen)
# Unit tests for sparse_attn_predict
print("\n=== Testing sparse_attn_predict ===")
test_sparse_attn_predict(1, 0) # bs=1, generation only
test_sparse_attn_predict(2, 0) # bs=2, generation only
test_sparse_attn_predict(3, 0) # bs=3, generation only
test_sparse_attn_predict(5, 3) # bs=5, mixed (3 ctx + 2 gen)
test_sparse_attn_predict(6, 2) # bs=6, mixed (2 ctx + 4 gen)

View File

@ -0,0 +1,413 @@
import math
import pytest
import torch
from tensorrt_llm._torch.attention_backend.sparse.kernel import (
triton_bmm,
triton_rocket_paged_kt_cache_bmm,
)
def pytorch_reference_bmm(
q: torch.Tensor,
k: torch.Tensor,
q_cu_seqlens: torch.Tensor,
k_cu_seqlens: torch.Tensor,
batch_size: int,
sm_scale: float = None,
causal: bool = False,
) -> torch.Tensor:
num_q_heads, total_q_tokens, head_dim = q.shape
num_k_heads, total_k_tokens, _ = k.shape
if sm_scale is None:
sm_scale = 1.0 / math.sqrt(head_dim)
# Compute q_len_per_seq
q_len_per_seq = total_q_tokens // batch_size
scores = torch.full(
(num_q_heads, q_len_per_seq, total_k_tokens),
float("-inf"),
dtype=torch.float32,
device=q.device,
)
# Process each batch
for batch_idx in range(batch_size):
q_start = q_cu_seqlens[batch_idx].item()
q_end = q_cu_seqlens[batch_idx + 1].item()
k_start = k_cu_seqlens[batch_idx].item()
k_end = k_cu_seqlens[batch_idx + 1].item()
q_seqlen = q_end - q_start
k_seqlen = k_end - k_start
if q_seqlen <= 0 or k_seqlen <= 0:
continue
q_batch = q[:, q_start:q_end, :] # [num_q_heads, q_seqlen, head_dim]
num_heads_per_kv = num_q_heads // num_k_heads
for head_idx in range(num_q_heads):
k_head_idx = head_idx // num_heads_per_kv
k_batch = k[k_head_idx, k_start:k_end, :] # [k_seqlen, head_dim]
qk = torch.matmul(q_batch[head_idx], k_batch.T) * sm_scale
if causal:
causal_mask = torch.triu(
torch.ones(q_seqlen, k_seqlen, device=q.device, dtype=torch.bool), diagonal=1
)
qk = qk.masked_fill(causal_mask, float("-inf"))
scores[head_idx, :q_seqlen, k_start:k_end] = qk
return scores
def create_kt_cache_from_k(
k: torch.Tensor,
kv_lens: torch.Tensor,
kt_page_size: int,
tokens_per_block: int,
num_kv_heads: int,
head_dim: int,
) -> tuple[torch.Tensor, torch.Tensor, int]:
"""
Create paged KT cache tensor from continuous K tensor.
Args:
k: Key tensor [num_gen_tokens, num_kv_heads * head_dim]
kv_lens: Sequence lengths [num_gen_tokens]
kt_page_size: Page size for KT tokens
tokens_per_block: Tokens per cache block
num_kv_heads: Number of KV heads
head_dim: Head dimension
Returns:
kt_cache_tensor: KT cache [num_blocks, tokens_per_block, num_kv_heads, 2*head_dim]
kt_cache_block_offsets: Block offsets [num_gen_tokens, max_kt_blocks_per_seq]
max_kt_blocks_per_seq: Maximum KT blocks per sequence
"""
num_gen_tokens = k.shape[0]
device = k.device
dtype = k.dtype
# Calculate number of kt tokens per sequence
num_kt_tokens_per_seq = [
(kv_len.item() + kt_page_size - 1) // kt_page_size for kv_len in kv_lens
]
max_kt_tokens = max(num_kt_tokens_per_seq)
max_kt_blocks_per_seq = (max_kt_tokens + tokens_per_block - 1) // tokens_per_block
# Calculate total number of blocks needed
total_blocks_needed = sum(
(kt_tokens + tokens_per_block - 1) // tokens_per_block
for kt_tokens in num_kt_tokens_per_seq
)
# Create KT cache tensor
kt_cache_tensor = torch.zeros(
(total_blocks_needed, tokens_per_block, num_kv_heads, 2 * head_dim),
device=device,
dtype=dtype,
)
# Create block offsets tensor
kt_cache_block_offsets = torch.full(
(num_gen_tokens, max_kt_blocks_per_seq), -1, dtype=torch.int32, device=device
)
# Fill KT cache and block offsets
current_block_idx = 0
for seq_idx in range(num_gen_tokens):
kv_len = kv_lens[seq_idx].item()
num_kt_tokens = num_kt_tokens_per_seq[seq_idx]
# Reshape k for this sequence: [num_kv_heads, head_dim]
k_seq = k[seq_idx].view(num_kv_heads, head_dim)
# Process each kt token (page)
for kt_idx in range(num_kt_tokens):
page_start = kt_idx * kt_page_size
            # For simplicity, every page stores this sequence's single K vector
            # as both its min and max summary; a real KT cache would take the
            # element-wise min/max over the tokens in the page.
token_idx = page_start
if token_idx < kv_len:
k_val = k_seq # [num_kv_heads, head_dim]
# Store k_min and k_max (for testing, we use same value)
kt_min = k_val
kt_max = k_val
# Determine which block this kt token belongs to
block_offset = kt_idx // tokens_per_block
token_offset_in_block = kt_idx % tokens_per_block
# Assign block index if not already assigned
if kt_cache_block_offsets[seq_idx, block_offset] < 0:
kt_cache_block_offsets[seq_idx, block_offset] = current_block_idx
current_block_idx += 1
block_idx = kt_cache_block_offsets[seq_idx, block_offset].item()
# Store in cache: [block, token_in_block, head, 2*head_dim]
kt_cache_tensor[block_idx, token_offset_in_block, :, :head_dim] = kt_min
kt_cache_tensor[block_idx, token_offset_in_block, :, head_dim:] = kt_max
return kt_cache_tensor, kt_cache_block_offsets, max_kt_blocks_per_seq
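# --- Illustrative sketch (not part of the diff) -------------------------------
# The helper above stores the same K vector as both halves of every KT entry,
# which is enough to exercise the kernel's paging and indexing. A real RocketKV
# KT summary instead takes the element-wise min and max of K over each page of
# `kt_page_size` tokens; a minimal dense version with assumed shapes:
import torch

def build_kt_summary(k_seq: torch.Tensor, kt_page_size: int) -> torch.Tensor:
    # k_seq: [seq_len, num_kv_heads, head_dim] -> [num_pages, num_kv_heads, 2 * head_dim]
    seq_len, num_kv_heads, head_dim = k_seq.shape
    num_pages = (seq_len + kt_page_size - 1) // kt_page_size
    pad = num_pages * kt_page_size - seq_len
    if pad > 0:
        # Pad with the last token so the tail page's min/max stay in range.
        k_seq = torch.cat([k_seq, k_seq[-1:].expand(pad, -1, -1)], dim=0)
    pages = k_seq.reshape(num_pages, kt_page_size, num_kv_heads, head_dim)
    return torch.cat([pages.amin(dim=1), pages.amax(dim=1)], dim=-1)
# -------------------------------------------------------------------------------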
def pytorch_reference_paged_kt_cache_bmm(
q: torch.Tensor,
k: torch.Tensor,
dim_pos: torch.Tensor,
kv_lens: torch.Tensor,
kt_page_size: int,
sm_scale: float = None,
) -> torch.Tensor:
num_gen_tokens, num_kv_heads, num_heads_per_kv, head_dim = q.shape
total_num_heads = num_kv_heads * num_heads_per_kv
device = q.device
if sm_scale is None:
sm_scale = 1.0 / math.sqrt(head_dim)
max_kt_tokens = max((kv_len.item() + kt_page_size - 1) // kt_page_size for kv_len in kv_lens)
total_kt_tokens = num_gen_tokens * max_kt_tokens
scores = torch.zeros((total_num_heads, 1, total_kt_tokens), dtype=torch.float32, device=device)
# Process each generation token
for batch_idx in range(num_gen_tokens):
kv_len = kv_lens[batch_idx].item()
num_kt_tokens = (kv_len + kt_page_size - 1) // kt_page_size
q_batch = q[batch_idx] # [num_kv_heads, num_heads_per_kv, head_dim]
k_batch = k[batch_idx].view(num_kv_heads, head_dim) # [num_kv_heads, head_dim]
dim_pos_batch = dim_pos[batch_idx] # [num_kv_heads, head_dim]
output_offset = batch_idx * max_kt_tokens
for kv_head_idx in range(num_kv_heads):
for q_head_idx in range(num_heads_per_kv):
global_head_idx = kv_head_idx * num_heads_per_kv + q_head_idx
q_vec = q_batch[kv_head_idx, q_head_idx] # [head_dim]
k_vec = k_batch[kv_head_idx] # [head_dim]
dim_pos_vec = dim_pos_batch[kv_head_idx] # [head_dim]
                # dim_pos normally selects k_max (offset >= head_dim) or k_min per
                # dimension; the test cache stores identical min/max, so the
                # selection collapses to k_vec either way.
                k_selected = torch.where(dim_pos_vec > 0, k_vec, k_vec)
# Compute score for each kt token (simplified)
for kt_idx in range(num_kt_tokens):
score = torch.dot(q_vec, k_selected) * sm_scale
scores[global_head_idx, 0, output_offset + kt_idx] = score
return scores
@pytest.mark.parametrize(
"batch_size,q_len_per_seq,k_lens,num_q_heads,num_kv_heads,head_dim,causal",
[
# Single batch
(1, 32, [128], 8, 8, 128, False),
(1, 32, [128], 8, 8, 128, True),
# Multiple batches with different k_len
(3, 32, [64, 128, 256], 8, 8, 128, False),
(4, 16, [100, 200, 150, 300], 32, 8, 128, False),
# Edge cases
(2, 1, [10, 20], 8, 8, 128, False), # q_len=1
(2, 64, [64, 128], 16, 4, 64, True), # Different head_dim with causal
],
)
def test_triton_bmm(batch_size, q_len_per_seq, k_lens, num_q_heads, num_kv_heads, head_dim, causal):
device = torch.device("cuda")
dtype = torch.float32
total_q_tokens = batch_size * q_len_per_seq
q_cu_seqlens = torch.arange(0, batch_size + 1, device=device, dtype=torch.int32) * q_len_per_seq
k_lens_tensor = torch.tensor(k_lens, dtype=torch.int32, device=device)
k_cu_seqlens = torch.cat(
[torch.zeros(1, dtype=torch.int32, device=device), torch.cumsum(k_lens_tensor, dim=0)]
)
total_k_tokens = k_cu_seqlens[-1].item()
q = torch.randn((num_q_heads, total_q_tokens, head_dim), dtype=dtype, device=device)
k = torch.randn((num_kv_heads, total_k_tokens, head_dim), dtype=dtype, device=device)
triton_scores = triton_bmm(
q=q,
k=k,
q_cu_seqlens=q_cu_seqlens,
k_cu_seqlens=k_cu_seqlens,
batch_size=batch_size,
sm_scale=None,
causal=causal,
)
reference_scores = pytorch_reference_bmm(
q=q,
k=k,
q_cu_seqlens=q_cu_seqlens,
k_cu_seqlens=k_cu_seqlens,
batch_size=batch_size,
sm_scale=None,
causal=causal,
)
# Compare results
# Handle -inf values separately
triton_finite = torch.isfinite(triton_scores)
reference_finite = torch.isfinite(reference_scores)
# Check that inf/finite masks match
assert torch.all(triton_finite == reference_finite), (
"Finite/infinite mask mismatch between Triton and reference"
)
# Compare finite values
if triton_finite.any():
max_diff = torch.max(
torch.abs(triton_scores[triton_finite] - reference_scores[reference_finite])
).item()
print(f"Max absolute difference: {max_diff:.6f}")
assert max_diff < 0.01, f"Max difference {max_diff} exceeds threshold"
@pytest.mark.parametrize(
"batch_size,kv_lens,num_kv_heads,num_heads_per_kv,head_dim,kt_page_size,tokens_per_block",
[
# Single batch
(1, [128], 8, 4, 128, 3, 64),
# Multiple batches with different kv_len
(3, [100, 200, 150], 8, 4, 128, 3, 64),
],
)
def test_triton_rocket_paged_kt_cache_bmm(
batch_size, kv_lens, num_kv_heads, num_heads_per_kv, head_dim, kt_page_size, tokens_per_block
):
device = torch.device("cuda")
dtype = torch.bfloat16
num_gen_tokens = batch_size
kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32, device=device)
# Create Q tensor: [num_gen_tokens, num_kv_heads, num_heads_per_kv, head_dim]
q = torch.randn(
(num_gen_tokens, num_kv_heads, num_heads_per_kv, head_dim), dtype=dtype, device=device
)
# Create K tensor for reference: [num_gen_tokens, num_kv_heads * head_dim]
k = torch.randn((num_gen_tokens, num_kv_heads * head_dim), dtype=dtype, device=device)
# Create dim_pos: [num_gen_tokens, num_kv_heads, head_dim]
# Randomly set some dimensions to head_dim (positive) and others to 0
dim_pos = torch.zeros(
(num_gen_tokens, num_kv_heads, head_dim), dtype=torch.int32, device=device
)
for i in range(num_gen_tokens):
for j in range(num_kv_heads):
num_positive = torch.randint(0, head_dim, (1,)).item()
positive_indices = torch.randperm(head_dim)[:num_positive]
dim_pos[i, j, positive_indices] = head_dim
# Create paged KT cache
kt_cache_tensor, kt_cache_block_offsets, max_kt_blocks_per_seq = create_kt_cache_from_k(
k=k,
kv_lens=kv_lens_tensor,
kt_page_size=kt_page_size,
tokens_per_block=tokens_per_block,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
)
# Calculate output offsets
max_kt_tokens = max((kv_len + kt_page_size - 1) // kt_page_size for kv_len in kv_lens)
total_kt_tokens = num_gen_tokens * max_kt_tokens
output_offsets = (
torch.arange(0, num_gen_tokens + 1, device=device, dtype=torch.int32) * max_kt_tokens
)
triton_scores = triton_rocket_paged_kt_cache_bmm(
q=q,
kt_cache_tensor=kt_cache_tensor,
kt_cache_block_offsets=kt_cache_block_offsets,
dim_pos=dim_pos,
kv_lens=kv_lens_tensor,
output_offsets=output_offsets,
kt_page_size=kt_page_size,
tokens_per_block=tokens_per_block,
max_kt_blocks_per_seq=max_kt_blocks_per_seq,
total_kt_tokens=total_kt_tokens,
sm_scale=None,
)
reference_scores = pytorch_reference_paged_kt_cache_bmm(
q=q,
k=k,
dim_pos=dim_pos,
kv_lens=kv_lens_tensor,
kt_page_size=kt_page_size,
sm_scale=None,
)
# Compare results
# Only compare non-zero entries (valid kt tokens)
mask = torch.zeros_like(triton_scores, dtype=torch.bool)
for batch_idx in range(num_gen_tokens):
kv_len = kv_lens[batch_idx]
num_kt_tokens = (kv_len + kt_page_size - 1) // kt_page_size
offset = batch_idx * max_kt_tokens
mask[:, :, offset : offset + num_kt_tokens] = True
triton_valid = triton_scores[mask]
reference_valid = reference_scores[mask]
max_diff = torch.max(torch.abs(triton_valid - reference_valid)).item()
print(f"Max absolute difference: {max_diff:.6f}")
assert max_diff < 0.05, f"Max difference {max_diff} exceeds threshold"
if __name__ == "__main__":
print("\n" + "=" * 80)
print("Testing Triton BMM Kernel")
print("=" * 80)
# Test triton_bmm
print("\n--- Single batch, non-causal ---")
test_triton_bmm(1, 32, [128], 8, 8, 128, False)
print("\n--- Single batch, causal ---")
test_triton_bmm(1, 32, [128], 8, 8, 128, True)
print("\n--- Multiple batches, different k_len ---")
test_triton_bmm(3, 32, [64, 128, 256], 8, 8, 128, False)
print("\n" + "=" * 80)
print("Testing Triton Rocket Paged KT Cache BMM Kernel")
print("=" * 80)
# Test triton_rocket_paged_kt_cache_bmm
print("\n--- Single batch ---")
test_triton_rocket_paged_kt_cache_bmm(1, [128], 8, 4, 128, 3, 64)
print("\n--- Multiple batches, different kv_len ---")
test_triton_rocket_paged_kt_cache_bmm(3, [100, 200, 150], 8, 4, 128, 3, 64)
print("\n" + "=" * 80)
print("All tests passed!")
print("=" * 80)

View File

@ -0,0 +1,228 @@
import pytest
import torch
from tensorrt_llm._torch.attention_backend.sparse.kernel import triton_topk
def pytorch_reference_topk(
input_tensor: torch.Tensor,
kv_offsets: torch.Tensor,
kv_lens: torch.Tensor,
topk: int,
) -> torch.Tensor:
"""
Args:
input_tensor: Input scores [num_kv_heads, sum(kv_lens)]
kv_offsets: KV offsets [batch_size + 1]
kv_lens: KV lengths [batch_size]
topk: TopK parameter
Returns:
output_indices: Padded indices [num_kv_heads, batch_size, topk]
"""
num_kv_heads = input_tensor.shape[0]
batch_size = len(kv_lens)
device = input_tensor.device
# Compute max sequence length for padding
max_seq_len = kv_lens.max().item()
# Ensure padding size >= topk
pad_size = max(max_seq_len, topk)
# Create padded tensor [num_kv_heads, batch_size, pad_size]
padded_tensor = torch.full(
(num_kv_heads, batch_size, pad_size), float("-inf"), dtype=input_tensor.dtype, device=device
)
# Fill in actual values
for batch_idx in range(batch_size):
start = kv_offsets[batch_idx].item()
end = kv_offsets[batch_idx + 1].item()
seq_len = kv_lens[batch_idx].item()
for head_idx in range(num_kv_heads):
padded_tensor[head_idx, batch_idx, :seq_len] = input_tensor[head_idx, start:end]
# Perform batch topk: [num_kv_heads, batch_size, pad_size] -> [num_kv_heads, batch_size, topk]
topk_values, topk_indices = torch.topk(
padded_tensor,
k=topk,
dim=-1,
largest=True,
)
# Mask out invalid indices based on each batch's seq_len
seq_lens_expanded = kv_lens.to(device).unsqueeze(0).unsqueeze(-1) # [1, batch_size, 1]
# topk_indices: [num_kv_heads, batch_size, topk]
mask = topk_indices >= seq_lens_expanded
topk_indices.masked_fill_(mask, -1)
return topk_indices
def triton_topk_wrapper(
input_tensor: torch.Tensor,
kv_offsets: torch.Tensor,
kv_lens: torch.Tensor,
topk: int,
) -> torch.Tensor:
"""
Args:
input_tensor: Input scores [num_kv_heads, sum(kv_lens)]
kv_offsets: KV offsets [batch_size + 1]
kv_lens: KV lengths [batch_size]
topk: TopK parameter
Returns:
output_indices: Padded indices [num_kv_heads, batch_size, topk]
"""
num_kv_heads = input_tensor.shape[0]
batch_size = len(kv_lens)
device = input_tensor.device
sparse_lens = torch.tensor(
[min(topk, seq_len.item()) for seq_len in kv_lens], dtype=torch.int32, device=device
)
sparse_offsets = torch.cat(
[torch.zeros(1, dtype=torch.int32, device=device), torch.cumsum(sparse_lens, dim=0)]
).to(device)
total_sparse_indices = sparse_offsets[-1].item()
output_indices_flat = triton_topk(
input_tensor, kv_offsets, sparse_offsets, total_sparse_indices, topk
)
# Convert flat format to padded format [num_kv_heads, batch_size, topk]
output_indices_padded = torch.full(
(num_kv_heads, batch_size, topk), -1, dtype=torch.int32, device=device
)
for batch_idx in range(batch_size):
start = sparse_offsets[batch_idx].item()
end = sparse_offsets[batch_idx + 1].item()
actual_len = end - start
for head_idx in range(num_kv_heads):
output_indices_padded[head_idx, batch_idx, :actual_len] = output_indices_flat[
head_idx, start:end
]
return output_indices_padded
def compute_overlap_ratio(
triton_indices: torch.Tensor,
reference_indices: torch.Tensor,
kv_lens: torch.Tensor,
) -> float:
"""
Args:
triton_indices: Triton topk results [num_kv_heads, batch_size, topk]
reference_indices: Reference topk results [num_kv_heads, batch_size, topk]
kv_lens: KV lengths [batch_size]
Returns:
Average overlap ratio across all batches and heads
"""
num_kv_heads = triton_indices.shape[0]
batch_size = triton_indices.shape[1]
overlap_ratios = []
# Compare batch by batch
for batch_idx in range(batch_size):
for head_idx in range(num_kv_heads):
# Extract indices for this batch and head
triton_batch = triton_indices[head_idx, batch_idx, :].cpu().tolist()
reference_batch = reference_indices[head_idx, batch_idx, :].cpu().tolist()
# Filter out -1 (invalid/padding indices)
triton_set = set([x for x in triton_batch if x >= 0])
reference_set = set([x for x in reference_batch if x >= 0])
if len(reference_set) > 0:
overlap = len(triton_set & reference_set)
overlap_ratio = overlap / len(reference_set)
overlap_ratios.append(overlap_ratio)
if len(overlap_ratios) == 0:
return 1.0
return sum(overlap_ratios) / len(overlap_ratios)
@pytest.mark.parametrize(
"batch_size,seq_lens,num_kv_heads,topk",
[
# Single batch, seq_len > topk
(1, [3000], 1, 2048),
# Single batch, seq_len < topk
(1, [1000], 8, 2048),
# Multiple batches, mixed seq_len (some < topk, some > topk)
(6, [50, 150, 80, 300, 100, 256], 8, 128),
(6, [1000, 2500, 1500, 3000, 1800, 4000], 8, 2048),
],
)
def test_topk_kernel(batch_size, seq_lens, num_kv_heads, topk):
device = torch.device("cuda")
dtype = torch.float32
kv_lens = torch.tensor(seq_lens, dtype=torch.int32, device=device)
kv_offsets = torch.cat(
[torch.zeros(1, dtype=torch.int32, device=device), torch.cumsum(kv_lens, dim=0)]
).to(device)
total_tokens = kv_offsets[-1].item()
input_tensor = torch.randn((num_kv_heads, total_tokens), dtype=dtype, device=device)
triton_output = triton_topk_wrapper(
input_tensor=input_tensor,
kv_offsets=kv_offsets,
kv_lens=kv_lens,
topk=topk,
)
reference_output = pytorch_reference_topk(
input_tensor=input_tensor,
kv_offsets=kv_offsets,
kv_lens=kv_lens,
topk=topk,
)
overlap_ratio = compute_overlap_ratio(
triton_output,
reference_output,
kv_lens,
)
min_threshold = 0.99
print(f"overlap_ratio: {overlap_ratio}")
assert overlap_ratio >= min_threshold, (
f"Overlap ratio {overlap_ratio:.4f} is too low (< {min_threshold})"
)
if __name__ == "__main__":
print("\n" + "=" * 80)
print("Testing Triton TopK Kernel")
print("=" * 80)
# Single batch tests
print("\n--- Single batch, seq_len > topk ---")
test_topk_kernel(1, [3000], 8, 2048)
print("\n--- Single batch, seq_len < topk ---")
test_topk_kernel(1, [1000], 8, 2048)
print("\n--- Multiple batches, mixed seq_len ---")
test_topk_kernel(6, [50, 150, 80, 300, 100, 256], 8, 128)
test_topk_kernel(6, [1000, 2500, 1500, 3000, 1800, 4000], 8, 2048)
print("\n" + "=" * 80)
print("All tests passed!")
print("=" * 80)