mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
[None] [feat] Use triton kernels for RocketKV prediction module (#8682)
Signed-off-by: yuhangh <58161490+heyuhhh@users.noreply.github.com>
This commit is contained in:
parent
cc4c980e03
commit
f07e9977c6
@ -20,7 +20,7 @@ namespace tensorrt_llm
|
||||
{
|
||||
namespace kernels
|
||||
{
|
||||
template <int THREADS_PER_BLOCK>
|
||||
template <int THREADS_PER_BLOCK, int MAX_NUM_PAGES>
|
||||
__global__ void gatherKvPageOffsetsKernel(
|
||||
int32_t* output_kv_page_offsets, // [num_head_kv, batch_size, 2, max_num_pages_per_seq]
|
||||
int32_t* output_seq_lengths, // [num_head_kv, batch_size]
|
||||
@ -32,23 +32,33 @@ __global__ void gatherKvPageOffsetsKernel(
|
||||
// Each CUDA block processes one sequence from the batch for one head.
|
||||
int32_t const head_idx = blockIdx.x;
|
||||
int32_t const batch_idx = blockIdx.y;
|
||||
int32_t const indices_block_size = sparse_params.sparse_attn_indices_block_size;
|
||||
if (batch_idx >= batch_size)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
// Shared memory for reduction.
|
||||
__shared__ typename cub::BlockReduce<Pair, THREADS_PER_BLOCK>::TempStorage temp_storage;
|
||||
using BlockScan = cub::BlockScan<int32_t, THREADS_PER_BLOCK>;
|
||||
using BlockReduce = cub::BlockReduce<Pair, THREADS_PER_BLOCK>;
|
||||
|
||||
__shared__ typename BlockScan::TempStorage temp_storage_scan;
|
||||
__shared__ typename BlockReduce::TempStorage temp_storage_reduce;
|
||||
|
||||
__shared__ int32_t s_page_mask[MAX_NUM_PAGES];
|
||||
__shared__ int32_t s_cu_page_mask[MAX_NUM_PAGES];
|
||||
__shared__ int32_t s_scan_total; // Store total count from scan
|
||||
|
||||
// Get the range of sparse indices and the sequence length.
|
||||
int32_t const start_offset = sparse_params.sparse_attn_offsets[batch_idx];
|
||||
int32_t const end_offset = sparse_params.sparse_attn_offsets[batch_idx + 1];
|
||||
int32_t const total_pages = sparse_params.sparse_attn_offsets[batch_size];
|
||||
int32_t const num_sparse_pages = end_offset - start_offset;
|
||||
int32_t const sparse_attn_indices_stride = sparse_params.sparse_attn_indices_stride;
|
||||
int32_t const num_sparse_indices = end_offset - start_offset;
|
||||
int32_t const original_seq_len = seq_lengths[batch_idx];
|
||||
int32_t const ori_valid_pages = (original_seq_len + tokens_per_page - 1) / tokens_per_page;
|
||||
int32_t const page_loops = (ori_valid_pages + MAX_NUM_PAGES - 1) / MAX_NUM_PAGES;
|
||||
|
||||
// Get global sparse index.
|
||||
int32_t const sparse_idx_global = head_idx * total_pages + start_offset;
|
||||
int32_t const sparse_idx_global = head_idx * sparse_attn_indices_stride + start_offset;
|
||||
|
||||
// Get the base memory offset. shape: [batch_size, 2, max_num_pages_per_seq]
|
||||
size_t const src_base_offset = (size_t) batch_idx * 2 * max_num_pages_per_seq;
|
||||
@ -58,56 +68,119 @@ __global__ void gatherKvPageOffsetsKernel(
|
||||
int32_t local_max_page_index = -1;
|
||||
int32_t local_num_valid_pages = 0;
|
||||
|
||||
// Perform the gather operation.
|
||||
for (int32_t i = threadIdx.x; i < num_sparse_pages; i += blockDim.x)
|
||||
int32_t src_page_idx_offset = 0;
|
||||
int32_t dst_page_idx_offset = 0;
|
||||
for (int32_t loop_idx = 0; loop_idx < page_loops; loop_idx++)
|
||||
{
|
||||
// Get the source idx and offset.
|
||||
int32_t const src_idx = sparse_params.sparse_attn_indices[sparse_idx_global + i];
|
||||
if (src_idx < 0)
|
||||
src_page_idx_offset = loop_idx * MAX_NUM_PAGES;
|
||||
int32_t loop_num_valid_pages = min(MAX_NUM_PAGES, ori_valid_pages - src_page_idx_offset);
|
||||
for (int32_t i = threadIdx.x; i < MAX_NUM_PAGES; i += blockDim.x)
|
||||
{
|
||||
continue;
|
||||
s_page_mask[i] = 0;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (int32_t i = threadIdx.x; i < num_sparse_indices; i += blockDim.x)
|
||||
{
|
||||
int32_t const src_idx = sparse_params.sparse_attn_indices[sparse_idx_global + i];
|
||||
int32_t const src_idx_start = src_idx * indices_block_size;
|
||||
int32_t const src_idx_end = min(src_idx_start + indices_block_size, original_seq_len);
|
||||
for (int32_t j = src_idx_start; j < src_idx_end; j++)
|
||||
{
|
||||
int32_t const src_page_idx = j / tokens_per_page;
|
||||
if (src_page_idx >= src_page_idx_offset && src_page_idx < src_page_idx_offset + loop_num_valid_pages)
|
||||
{
|
||||
atomicExch(&s_page_mask[src_page_idx - src_page_idx_offset], 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Handle case when loop_num_valid_pages > blockDim.x by processing in chunks
|
||||
int32_t scan_offset = 0;
|
||||
int32_t const scan_chunks = (loop_num_valid_pages + blockDim.x - 1) / blockDim.x;
|
||||
|
||||
for (int32_t chunk_idx = 0; chunk_idx < scan_chunks; chunk_idx++)
|
||||
{
|
||||
int32_t const chunk_start = chunk_idx * blockDim.x;
|
||||
int32_t const chunk_size = min((int32_t) blockDim.x, loop_num_valid_pages - chunk_start);
|
||||
|
||||
int32_t thread_data = (threadIdx.x < chunk_size) ? s_page_mask[chunk_start + threadIdx.x] : 0;
|
||||
int32_t thread_output;
|
||||
int32_t aggregate;
|
||||
|
||||
BlockScan(temp_storage_scan).ExclusiveSum(thread_data, thread_output, aggregate);
|
||||
__syncthreads();
|
||||
|
||||
if (threadIdx.x < chunk_size)
|
||||
{
|
||||
s_cu_page_mask[chunk_start + threadIdx.x] = thread_output + scan_offset;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Update scan offset for next chunk
|
||||
scan_offset += aggregate;
|
||||
}
|
||||
|
||||
// Update the local max page index.
|
||||
local_max_page_index = max(local_max_page_index, src_idx);
|
||||
local_num_valid_pages++;
|
||||
if (threadIdx.x == 0)
|
||||
{
|
||||
s_scan_total = scan_offset;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Get the source and destination offsets.
|
||||
size_t const src_offset_dim0 = src_base_offset + 0 * max_num_pages_per_seq + src_idx;
|
||||
size_t const src_offset_dim1 = src_base_offset + 1 * max_num_pages_per_seq + src_idx;
|
||||
size_t const dst_offset_dim0 = dst_base_offset + 0 * max_num_pages_per_seq + i;
|
||||
size_t const dst_offset_dim1 = dst_base_offset + 1 * max_num_pages_per_seq + i;
|
||||
// Perform the gather operation.
|
||||
for (int32_t i = threadIdx.x; i < loop_num_valid_pages; i += blockDim.x)
|
||||
{
|
||||
// Skip if the page is not valid.
|
||||
if (s_page_mask[i] == 0)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
// Perform the gather operation: read from the sparse location and write to the dense location.
|
||||
output_kv_page_offsets[dst_offset_dim0] = kv_page_offsets[src_offset_dim0];
|
||||
output_kv_page_offsets[dst_offset_dim1] = kv_page_offsets[src_offset_dim1];
|
||||
int32_t const src_idx = src_page_idx_offset + i;
|
||||
int32_t const dst_idx = dst_page_idx_offset + s_cu_page_mask[i];
|
||||
|
||||
local_max_page_index = max(local_max_page_index, src_idx);
|
||||
local_num_valid_pages++;
|
||||
|
||||
size_t const src_offset_dim0 = src_base_offset + 0 * max_num_pages_per_seq + src_idx;
|
||||
size_t const src_offset_dim1 = src_base_offset + 1 * max_num_pages_per_seq + src_idx;
|
||||
size_t const dst_offset_dim0 = dst_base_offset + 0 * max_num_pages_per_seq + dst_idx;
|
||||
size_t const dst_offset_dim1 = dst_base_offset + 1 * max_num_pages_per_seq + dst_idx;
|
||||
|
||||
output_kv_page_offsets[dst_offset_dim0] = kv_page_offsets[src_offset_dim0];
|
||||
output_kv_page_offsets[dst_offset_dim1] = kv_page_offsets[src_offset_dim1];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Update dst offset using the total count from scan
|
||||
dst_page_idx_offset += s_scan_total;
|
||||
}
|
||||
|
||||
// Reduce the local max page indices and number of valid pages.
|
||||
Pair local_pair = {local_max_page_index, local_num_valid_pages};
|
||||
Pair result = cub::BlockReduce<Pair, THREADS_PER_BLOCK>(temp_storage).Reduce(local_pair, PairReduceOp());
|
||||
Pair result = BlockReduce(temp_storage_reduce).Reduce(local_pair, PairReduceOp());
|
||||
|
||||
// Update sequence length for this head and batch.
|
||||
if (threadIdx.x == 0)
|
||||
{
|
||||
int32_t const max_page_index = result.max_val;
|
||||
int32_t const num_valid_pages = result.sum_val;
|
||||
int32_t const ori_valid_pages = (original_seq_len + tokens_per_page - 1) / tokens_per_page;
|
||||
size_t const seq_len_offset = (size_t) head_idx * batch_size + batch_idx;
|
||||
int32_t seq_len = 0;
|
||||
if (num_valid_pages > 0)
|
||||
{
|
||||
int32_t seq_len = original_seq_len - (ori_valid_pages - num_valid_pages) * tokens_per_page;
|
||||
int32_t seq_len_remain = original_seq_len % tokens_per_page;
|
||||
if (max_page_index != ori_valid_pages - 1 && seq_len_remain != 0)
|
||||
if (max_page_index == ori_valid_pages - 1)
|
||||
{
|
||||
seq_len += tokens_per_page - seq_len_remain;
|
||||
seq_len = (num_valid_pages - 1) * tokens_per_page
|
||||
+ (original_seq_len - (ori_valid_pages - 1) * tokens_per_page);
|
||||
}
|
||||
else
|
||||
{
|
||||
seq_len = num_valid_pages * tokens_per_page;
|
||||
}
|
||||
output_seq_lengths[seq_len_offset] = seq_len;
|
||||
}
|
||||
else
|
||||
{
|
||||
output_seq_lengths[seq_len_offset] = 0;
|
||||
}
|
||||
output_seq_lengths[seq_len_offset] = seq_len;
|
||||
}
|
||||
}
|
||||
|
||||
@ -121,11 +194,8 @@ void invokeGatherKvPageOffsets(int32_t* output_kv_page_offsets, int32_t* output_
|
||||
dim3 grid(num_head_kv, batch_size, 1);
|
||||
// The block.
|
||||
dim3 block(256, 1, 1);
|
||||
// Shared memory size.
|
||||
size_t smem_size = sizeof(Pair) * 256;
|
||||
|
||||
// Launch the kernel.
|
||||
gatherKvPageOffsetsKernel<256><<<grid, block, smem_size, stream>>>(output_kv_page_offsets, output_seq_lengths,
|
||||
gatherKvPageOffsetsKernel<256, 512><<<grid, block, 0, stream>>>(output_kv_page_offsets, output_seq_lengths,
|
||||
kv_page_offsets, seq_lengths, sparse_params, batch_size, tokens_per_page, max_num_pages_per_seq);
|
||||
}
|
||||
} // namespace kernels
|
||||
|
||||
@ -35,6 +35,9 @@ struct SparseAttentionParams
|
||||
int32_t sparse_mla_topk{0}; // for DSA attention
|
||||
void* sparse_mla_kv_cache_pool{nullptr}; // for DSA attention
|
||||
|
||||
int32_t sparse_attn_indices_block_size{1};
|
||||
int32_t sparse_attn_indices_stride{0};
|
||||
|
||||
std::string toString() const
|
||||
{
|
||||
std::stringstream ss;
|
||||
@ -43,7 +46,9 @@ struct SparseAttentionParams
|
||||
<< "sparse_kv_offsets: " << this->sparse_kv_offsets << std::endl
|
||||
<< "sparse_attn_offsets: " << this->sparse_attn_offsets << std::endl
|
||||
<< "sparse_mla_topk: " << this->sparse_mla_topk << std::endl
|
||||
<< "sparse_mla_kv_cache_pool: " << this->sparse_mla_kv_cache_pool << std::endl;
|
||||
<< "sparse_mla_kv_cache_pool: " << this->sparse_mla_kv_cache_pool << std::endl
|
||||
<< "sparse_attn_indices_block_size: " << this->sparse_attn_indices_block_size << std::endl
|
||||
<< "sparse_attn_indices_stride: " << this->sparse_attn_indices_stride << std::endl;
|
||||
return ss.str();
|
||||
}
|
||||
};
|
||||
|
||||
@ -64,10 +64,11 @@ void initBindings(nb::module_& m)
|
||||
nb::arg("softmax_stats_tensor") = std::nullopt, nb::arg("spec_decoding_bool_params"),
|
||||
nb::arg("spec_decoding_tensor_params"), nb::arg("sparse_kv_indices") = std::nullopt,
|
||||
nb::arg("sparse_kv_offsets") = std::nullopt, nb::arg("sparse_attn_indices") = std::nullopt,
|
||||
nb::arg("sparse_attn_offsets") = std::nullopt, nb::arg("sparse_mla_topk") = std::nullopt,
|
||||
nb::arg("cu_q_seqlens") = std::nullopt, nb::arg("cu_kv_seqlens") = std::nullopt,
|
||||
nb::arg("fmha_scheduler_counter") = std::nullopt, nb::arg("mla_bmm1_scale") = std::nullopt,
|
||||
nb::arg("mla_bmm2_scale") = std::nullopt, nb::arg("quant_q_buffer") = std::nullopt,
|
||||
"Multi-head attention operation", nb::call_guard<nb::gil_scoped_release>());
|
||||
nb::arg("sparse_attn_offsets") = std::nullopt, nb::arg("sparse_attn_indices_block_size"),
|
||||
nb::arg("sparse_mla_topk") = std::nullopt, nb::arg("cu_q_seqlens") = std::nullopt,
|
||||
nb::arg("cu_kv_seqlens") = std::nullopt, nb::arg("fmha_scheduler_counter") = std::nullopt,
|
||||
nb::arg("mla_bmm1_scale") = std::nullopt, nb::arg("mla_bmm2_scale") = std::nullopt,
|
||||
nb::arg("quant_q_buffer") = std::nullopt, "Multi-head attention operation",
|
||||
nb::call_guard<nb::gil_scoped_release>());
|
||||
}
|
||||
} // namespace tensorrt_llm::nanobind::thop
|
||||
|
||||
@ -64,10 +64,11 @@ void initBindings(pybind11::module_& m)
|
||||
py::arg("softmax_stats_tensor") = std::nullopt, py::arg("spec_decoding_bool_params"),
|
||||
py::arg("spec_decoding_tensor_params"), py::arg("sparse_kv_indices") = std::nullopt,
|
||||
py::arg("sparse_kv_offsets") = std::nullopt, py::arg("sparse_attn_indices") = std::nullopt,
|
||||
py::arg("sparse_attn_offsets") = std::nullopt, py::arg("sparse_mla_topk") = std::nullopt,
|
||||
py::arg("cu_q_seqlens") = std::nullopt, py::arg("cu_kv_seqlens") = std::nullopt,
|
||||
py::arg("fmha_scheduler_counter") = std::nullopt, py::arg("mla_bmm1_scale") = std::nullopt,
|
||||
py::arg("mla_bmm2_scale") = std::nullopt, py::arg("quant_q_buffer") = std::nullopt,
|
||||
"Multi-head attention operation", py::call_guard<py::gil_scoped_release>());
|
||||
py::arg("sparse_attn_offsets") = std::nullopt, py::arg("sparse_attn_indices_block_size"),
|
||||
py::arg("sparse_mla_topk") = std::nullopt, py::arg("cu_q_seqlens") = std::nullopt,
|
||||
py::arg("cu_kv_seqlens") = std::nullopt, py::arg("fmha_scheduler_counter") = std::nullopt,
|
||||
py::arg("mla_bmm1_scale") = std::nullopt, py::arg("mla_bmm2_scale") = std::nullopt,
|
||||
py::arg("quant_q_buffer") = std::nullopt, "Multi-head attention operation",
|
||||
py::call_guard<py::gil_scoped_release>());
|
||||
}
|
||||
} // namespace tensorrt_llm::pybind::thop
|
||||
|
||||
@ -86,10 +86,11 @@ public:
|
||||
c10::ArrayRef<std::optional<torch::Tensor>> spec_decoding_tensor_params,
|
||||
torch::optional<torch::Tensor> attention_sinks, torch::optional<torch::Tensor> sparse_kv_indices,
|
||||
torch::optional<torch::Tensor> sparse_kv_offsets, torch::optional<torch::Tensor> sparse_attn_indices,
|
||||
torch::optional<torch::Tensor> sparse_attn_offsets, int32_t const sparse_mla_topk,
|
||||
std::optional<torch::Tensor> cu_q_seqlens, std::optional<torch::Tensor> cu_kv_seqlens,
|
||||
std::optional<torch::Tensor> fmha_scheduler_counter, std::optional<torch::Tensor> mla_bmm1_scale,
|
||||
std::optional<torch::Tensor> mla_bmm2_scale, std::optional<torch::Tensor> quant_q_buffer) const
|
||||
torch::optional<torch::Tensor> sparse_attn_offsets, int64_t const sparse_attn_indices_block_size,
|
||||
int32_t const sparse_mla_topk, std::optional<torch::Tensor> cu_q_seqlens,
|
||||
std::optional<torch::Tensor> cu_kv_seqlens, std::optional<torch::Tensor> fmha_scheduler_counter,
|
||||
std::optional<torch::Tensor> mla_bmm1_scale, std::optional<torch::Tensor> mla_bmm2_scale,
|
||||
std::optional<torch::Tensor> quant_q_buffer) const
|
||||
= 0;
|
||||
};
|
||||
|
||||
@ -146,10 +147,11 @@ public:
|
||||
c10::ArrayRef<std::optional<torch::Tensor>> spec_decoding_tensor_params,
|
||||
torch::optional<torch::Tensor> attention_sinks, torch::optional<torch::Tensor> sparse_kv_indices,
|
||||
torch::optional<torch::Tensor> sparse_kv_offsets, torch::optional<torch::Tensor> sparse_attn_indices,
|
||||
torch::optional<torch::Tensor> sparse_attn_offsets, int32_t const sparse_mla_topk,
|
||||
std::optional<torch::Tensor> cu_q_seqlens, std::optional<torch::Tensor> cu_kv_seqlens,
|
||||
std::optional<torch::Tensor> fmha_scheduler_counter, std::optional<torch::Tensor> mla_bmm1_scale,
|
||||
std::optional<torch::Tensor> mla_bmm2_scale, std::optional<torch::Tensor> quant_q_buffer) const override
|
||||
torch::optional<torch::Tensor> sparse_attn_offsets, int64_t const sparse_attn_indices_block_size,
|
||||
int32_t const sparse_mla_topk, std::optional<torch::Tensor> cu_q_seqlens,
|
||||
std::optional<torch::Tensor> cu_kv_seqlens, std::optional<torch::Tensor> fmha_scheduler_counter,
|
||||
std::optional<torch::Tensor> mla_bmm1_scale, std::optional<torch::Tensor> mla_bmm2_scale,
|
||||
std::optional<torch::Tensor> quant_q_buffer) const override
|
||||
{
|
||||
auto stream = at::cuda::getCurrentCUDAStream(qkv_or_q.get_device());
|
||||
T* attention_input = static_cast<T*>(qkv_or_q.slice(0, token_offset).data_ptr());
|
||||
@ -395,6 +397,9 @@ public:
|
||||
= sparse_attn_indices.has_value() ? sparse_attn_indices.value().data_ptr<int32_t>() : nullptr;
|
||||
op.mRuntimeSparseAttentionParams.sparse_attn_offsets
|
||||
= sparse_attn_offsets.has_value() ? sparse_attn_offsets.value().data_ptr<int32_t>() : nullptr;
|
||||
op.mRuntimeSparseAttentionParams.sparse_attn_indices_block_size = sparse_attn_indices_block_size;
|
||||
op.mRuntimeSparseAttentionParams.sparse_attn_indices_stride
|
||||
= sparse_attn_indices.has_value() ? sparse_attn_indices.value().size(-1) : 0;
|
||||
if (op.isMLAEnabled() && op.mUseSparseAttention)
|
||||
{
|
||||
op.mRuntimeSparseAttentionParams.sparse_mla_topk = sparse_mla_topk;
|
||||
@ -589,10 +594,10 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
|
||||
std::vector<std::optional<torch::Tensor>> spec_decoding_tensor_params,
|
||||
std::optional<torch::Tensor> sparse_kv_indices, std::optional<torch::Tensor> sparse_kv_offsets,
|
||||
std::optional<torch::Tensor> sparse_attn_indices, std::optional<torch::Tensor> sparse_attn_offsets,
|
||||
std::optional<int64_t> sparse_mla_topk, std::optional<torch::Tensor> cu_q_seqlens,
|
||||
std::optional<torch::Tensor> cu_kv_seqlens, std::optional<torch::Tensor> fmha_scheduler_counter,
|
||||
std::optional<torch::Tensor> mla_bmm1_scale, std::optional<torch::Tensor> mla_bmm2_scale,
|
||||
std::optional<torch::Tensor> quant_q_buffer)
|
||||
int64_t const sparse_attn_indices_block_size, std::optional<int64_t> sparse_mla_topk,
|
||||
std::optional<torch::Tensor> cu_q_seqlens, std::optional<torch::Tensor> cu_kv_seqlens,
|
||||
std::optional<torch::Tensor> fmha_scheduler_counter, std::optional<torch::Tensor> mla_bmm1_scale,
|
||||
std::optional<torch::Tensor> mla_bmm2_scale, std::optional<torch::Tensor> quant_q_buffer)
|
||||
{
|
||||
TLLM_LOG_TRACE("Attention op starts at layer %d", layer_idx);
|
||||
// Use these tensors to infer if the attention is using KV cache
|
||||
@ -847,8 +852,8 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
|
||||
rotary_inv_freq, rotary_cos_sin, latent_cache, q_pe, block_ids_per_seq, mrope_rotary_cos_sin,
|
||||
mrope_position_deltas, mla_tensor_params, softmax_stats_tensor, spec_decoding_tensor_params,
|
||||
attention_sinks, sparse_kv_indices, sparse_kv_offsets, sparse_attn_indices, sparse_attn_offsets,
|
||||
sparse_mla_topk_value, cu_q_seqlens, cu_kv_seqlens, fmha_scheduler_counter, mla_bmm1_scale, mla_bmm2_scale,
|
||||
quant_q_buffer);
|
||||
sparse_attn_indices_block_size, sparse_mla_topk_value, cu_q_seqlens, cu_kv_seqlens, fmha_scheduler_counter,
|
||||
mla_bmm1_scale, mla_bmm2_scale, quant_q_buffer);
|
||||
}
|
||||
|
||||
if ((num_generations > 0) && (attn_input_type != AttentionInputType::ContextOnly))
|
||||
@ -866,8 +871,8 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
|
||||
rotary_inv_freq, rotary_cos_sin, latent_cache, q_pe, block_ids_per_seq, mrope_rotary_cos_sin,
|
||||
mrope_position_deltas, mla_tensor_params, softmax_stats_tensor, spec_decoding_tensor_params,
|
||||
attention_sinks, sparse_kv_indices, sparse_kv_offsets, sparse_attn_indices, sparse_attn_offsets,
|
||||
sparse_mla_topk_value, cu_q_seqlens, cu_kv_seqlens, fmha_scheduler_counter, mla_bmm1_scale, mla_bmm2_scale,
|
||||
quant_q_buffer);
|
||||
sparse_attn_indices_block_size, sparse_mla_topk_value, cu_q_seqlens, cu_kv_seqlens, fmha_scheduler_counter,
|
||||
mla_bmm1_scale, mla_bmm2_scale, quant_q_buffer);
|
||||
}
|
||||
|
||||
TLLM_LOG_TRACE("Attention op stops at layer %d", layer_idx);
|
||||
|
||||
@ -63,9 +63,9 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
|
||||
std::vector<std::optional<torch::Tensor>> spec_decoding_tensor_params,
|
||||
std::optional<torch::Tensor> sparse_kv_indices, std::optional<torch::Tensor> sparse_kv_offsets,
|
||||
std::optional<torch::Tensor> sparse_attn_indices, std::optional<torch::Tensor> sparse_attn_offsets,
|
||||
std::optional<int64_t> sparse_mla_topk, std::optional<torch::Tensor> cu_q_seqlens,
|
||||
std::optional<torch::Tensor> cu_kv_seqlens, std::optional<torch::Tensor> fmha_scheduler_counter,
|
||||
std::optional<torch::Tensor> mla_bmm1_scale, std::optional<torch::Tensor> mla_bmm2_scale,
|
||||
std::optional<torch::Tensor> quant_q_buffer);
|
||||
int64_t const sparse_attn_indices_block_size, std::optional<int64_t> sparse_mla_topk,
|
||||
std::optional<torch::Tensor> cu_q_seqlens, std::optional<torch::Tensor> cu_kv_seqlens,
|
||||
std::optional<torch::Tensor> fmha_scheduler_counter, std::optional<torch::Tensor> mla_bmm1_scale,
|
||||
std::optional<torch::Tensor> mla_bmm2_scale, std::optional<torch::Tensor> quant_q_buffer);
|
||||
|
||||
} // namespace torch_ext
|
||||
|
||||
@ -36,14 +36,16 @@ TEST_F(sparseAttentionKernelsTest, GatherKvPageOffsetsKernelTest)
|
||||
constexpr int num_head_kv = 4;
|
||||
constexpr int max_num_pages_per_seq = 8;
|
||||
constexpr int tokens_per_page = 64;
|
||||
constexpr int total_sparse_pages = max_batch_size * max_num_pages_per_seq; // Total sparse pages across all batches
|
||||
// Batch 0 has 8 sparse tokens, Batch 1 has 6 sparse tokens, total = 14
|
||||
constexpr int total_sparse_tokens = 14;
|
||||
|
||||
// Create input buffers
|
||||
auto kv_page_offsets
|
||||
= mBufferManager->gpu(ITensor::makeShape({batch_size, 2, max_num_pages_per_seq}), nvinfer1::DataType::kINT32);
|
||||
auto seq_lengths = mBufferManager->gpu(ITensor::makeShape({batch_size}), nvinfer1::DataType::kINT32);
|
||||
// Shape: [num_head_kv, total_sparse_tokens] - flattened across all batches
|
||||
auto sparse_indices
|
||||
= mBufferManager->gpu(ITensor::makeShape({total_sparse_pages, num_head_kv}), nvinfer1::DataType::kINT32);
|
||||
= mBufferManager->gpu(ITensor::makeShape({num_head_kv, total_sparse_tokens}), nvinfer1::DataType::kINT32);
|
||||
auto sparse_indices_offsets = mBufferManager->gpu(ITensor::makeShape({batch_size + 1}), nvinfer1::DataType::kINT32);
|
||||
|
||||
// Create output buffers
|
||||
@ -57,7 +59,7 @@ TEST_F(sparseAttentionKernelsTest, GatherKvPageOffsetsKernelTest)
|
||||
ITensor::makeShape({batch_size, 2, max_num_pages_per_seq}), nvinfer1::DataType::kINT32);
|
||||
auto seq_lengths_host = mBufferManager->pinned(ITensor::makeShape({batch_size}), nvinfer1::DataType::kINT32);
|
||||
auto sparse_indices_host
|
||||
= mBufferManager->pinned(ITensor::makeShape({total_sparse_pages, num_head_kv}), nvinfer1::DataType::kINT32);
|
||||
= mBufferManager->pinned(ITensor::makeShape({num_head_kv, total_sparse_tokens}), nvinfer1::DataType::kINT32);
|
||||
auto sparse_indices_offsets_host
|
||||
= mBufferManager->pinned(ITensor::makeShape({batch_size + 1}), nvinfer1::DataType::kINT32);
|
||||
|
||||
@ -81,27 +83,43 @@ TEST_F(sparseAttentionKernelsTest, GatherKvPageOffsetsKernelTest)
|
||||
}
|
||||
|
||||
// Initialize sequence lengths
|
||||
seq_lengths_ptr[0] = 2 * tokens_per_page + 18; // 3 pages for batch 0
|
||||
seq_lengths_ptr[1] = 3 * tokens_per_page + 3; // 4 pages for batch 1
|
||||
seq_lengths_ptr[0] = 2 * tokens_per_page + 18; // 3 pages (146 tokens) for batch 0
|
||||
seq_lengths_ptr[1] = 3 * tokens_per_page + 3; // 4 pages (195 tokens) for batch 1
|
||||
|
||||
// Initialize sparse indices with different patterns for different heads
|
||||
// Shape: {total_sparse_pages, num_head_kv}
|
||||
// Each head can have its own sparse pattern
|
||||
int num_sparse_pages = 5;
|
||||
int sparse_page_indices[5][4] = {{1, 0, 0, 1}, {2, 1, 1, -1}, {-1, 2, -1, -1}, {0, 1, 2, 3}, {3, 3, 3, -1}};
|
||||
for (int page = 0; page < num_sparse_pages; ++page)
|
||||
// Initialize sparse indices with token-level indices (indices_block_size = 1)
|
||||
// Shape: [num_head_kv, total_sparse_tokens]
|
||||
// All heads have the same number of sparse tokens: 8 for batch 0, 6 for batch 1
|
||||
// Memory layout: sparse_indices_ptr[head_idx * total_sparse_tokens + token_offset]
|
||||
std::vector<std::vector<int>> sparse_tokens_per_head
|
||||
= {// Head 0: Batch 0 [10,20,70,75,90,95,100,105] -> pages [0,0,1,1,1,1,1,1] -> unique [0,1]
|
||||
// Batch 1 [64,65,128,129,192,193] -> pages [1,1,2,2,3,3] -> unique [1,2,3]
|
||||
{10, 20, 70, 75, 90, 95, 100, 105, 64, 65, 128, 129, 192, 193},
|
||||
|
||||
// Head 1: Batch 0 [5,6,65,66,130,131,135,140] -> pages [0,0,1,1,2,2,2,2] -> unique [0,1,2]
|
||||
// Batch 1 [70,71,128,129,190,191] -> pages [1,1,2,2,2,2] -> unique [1,2]
|
||||
{5, 6, 65, 66, 130, 131, 135, 140, 70, 71, 128, 129, 190, 191},
|
||||
|
||||
// Head 2: Batch 0 [20,21,80,81,85,86,90,91] -> pages [0,0,1,1,1,1,1,1] -> unique [0,1]
|
||||
// Batch 1 [64,65,66,67,68,69] -> pages [1,1,1,1,1,1] -> unique [1]
|
||||
{20, 21, 80, 81, 85, 86, 90, 91, 64, 65, 66, 67, 68, 69},
|
||||
|
||||
// Head 3: Batch 0 [70,71,72,73,74,75,76,77] -> pages [1,1,1,1,1,1,1,1] -> unique [1]
|
||||
// Batch 1 [192,193,194,195,196,197] -> pages [3,3,3,3,3,3] -> unique [3]
|
||||
{70, 71, 72, 73, 74, 75, 76, 77, 192, 193, 194, 195, 196, 197}};
|
||||
|
||||
// Fill sparse_indices_ptr using the defined data
|
||||
for (int head = 0; head < num_head_kv; ++head)
|
||||
{
|
||||
for (int head = 0; head < num_head_kv; ++head)
|
||||
for (int token_idx = 0; token_idx < total_sparse_tokens; ++token_idx)
|
||||
{
|
||||
int idx = head * num_sparse_pages + page;
|
||||
sparse_indices_ptr[idx] = sparse_page_indices[page][head];
|
||||
sparse_indices_ptr[head * total_sparse_tokens + token_idx] = sparse_tokens_per_head[head][token_idx];
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize sparse indices offsets
|
||||
sparse_indices_offsets_ptr[0] = 0; // Start of batch 0
|
||||
sparse_indices_offsets_ptr[1] = 3; // Start of batch 1 (3 sparse pages for batch 0)
|
||||
sparse_indices_offsets_ptr[2] = 5; // End (3 sparse pages for batch 1)
|
||||
// Initialize sparse indices offsets (these are per-batch offsets into the flattened array)
|
||||
sparse_indices_offsets_ptr[0] = 0; // Start of batch 0
|
||||
sparse_indices_offsets_ptr[1] = 8; // Start of batch 1 (batch 0 has 8 sparse tokens)
|
||||
sparse_indices_offsets_ptr[2] = 14; // End (batch 1 has 6 sparse tokens, total = 14)
|
||||
|
||||
// Copy data to GPU
|
||||
mBufferManager->copy(*kv_page_offsets_host, *kv_page_offsets);
|
||||
@ -112,6 +130,8 @@ TEST_F(sparseAttentionKernelsTest, GatherKvPageOffsetsKernelTest)
|
||||
SparseAttentionParams sparse_params;
|
||||
sparse_params.sparse_attn_indices = bufferCast<int32_t>(*sparse_indices);
|
||||
sparse_params.sparse_attn_offsets = bufferCast<int32_t>(*sparse_indices_offsets);
|
||||
sparse_params.sparse_attn_indices_block_size = 1; // Token-level indexing
|
||||
sparse_params.sparse_attn_indices_stride = total_sparse_tokens;
|
||||
|
||||
// Launch the kernel
|
||||
invokeGatherKvPageOffsets(bufferCast<int32_t>(*output_kv_page_offsets), bufferCast<int32_t>(*output_seq_lengths),
|
||||
@ -136,21 +156,49 @@ TEST_F(sparseAttentionKernelsTest, GatherKvPageOffsetsKernelTest)
|
||||
auto output_kv_offsets_ptr = bufferCast<int32_t>(*output_kv_page_offsets_host);
|
||||
auto output_seq_len_ptr = bufferCast<int>(*output_seq_lengths_host);
|
||||
|
||||
// Verify sequence lengths for each head and batch
|
||||
int expected_seq_lens[4][2] = {
|
||||
{tokens_per_page + 18, tokens_per_page + 3}, // Head 0: batch 0 has 2 pages, batch 1 has 0 pages
|
||||
{2 * tokens_per_page + 18, tokens_per_page + 3}, // Head 1: batch 0 has 3 pages, batch 1 has 0 pages
|
||||
{2 * tokens_per_page, tokens_per_page + 3}, // Head 2: batch 0 has 2 pages, batch 1 has 0 pages
|
||||
{tokens_per_page, 3} // Head 3: batch 0 has 2 pages, batch 1 has 0 pages
|
||||
// Define expected results for each head and batch
|
||||
// Format: {num_pages, {page_indices...}, seq_len}
|
||||
struct ExpectedResult
|
||||
{
|
||||
int num_pages;
|
||||
std::vector<int> page_indices;
|
||||
int seq_len;
|
||||
};
|
||||
|
||||
ExpectedResult expected_results[4][2] = {// Head 0
|
||||
{ // Batch 0: tokens on pages [0,1] -> 2 pages, seq_len = 2 * 64 = 128
|
||||
{2, {0, 1}, 2 * tokens_per_page},
|
||||
// Batch 1: tokens on pages [1,2,3] -> 3 pages, max_page=3 (last page)
|
||||
// seq_len = 195 - (4-3)*64 = 131 (no padding needed, max_page is last page)
|
||||
{3, {1, 2, 3}, 131}},
|
||||
// Head 1
|
||||
{// Batch 0: tokens on pages [0,1,2] -> 3 pages (all), seq_len = 146
|
||||
{3, {0, 1, 2}, 2 * tokens_per_page + 18},
|
||||
// Batch 1: tokens on pages [1,2] -> 2 pages, max_page=2 (not last page)
|
||||
// seq_len = 195 - (4-2)*64 = 67, padding: 67 + (64-3) = 128
|
||||
{2, {1, 2}, 2 * tokens_per_page}},
|
||||
// Head 2
|
||||
{// Batch 0: tokens on pages [0,1] -> 2 pages, seq_len = 128
|
||||
{2, {0, 1}, 2 * tokens_per_page},
|
||||
// Batch 1: tokens on page [1] -> 1 page, max_page=1 (not last page)
|
||||
// seq_len = 195 - (4-1)*64 = 3, padding: 3 + (64-3) = 64
|
||||
{1, {1}, tokens_per_page}},
|
||||
// Head 3
|
||||
{// Batch 0: tokens on page [1] -> 1 page, seq_len = 64
|
||||
{1, {1}, tokens_per_page},
|
||||
// Batch 1: tokens on page [3] -> 1 page, max_page=3 (last page)
|
||||
// seq_len = 195 - (4-1)*64 = 3 (no padding needed, max_page is last page)
|
||||
{1, {3}, 3}}};
|
||||
|
||||
// Verify sequence lengths for each head and batch
|
||||
for (int h = 0; h < num_head_kv; ++h)
|
||||
{
|
||||
for (int b = 0; b < batch_size; ++b)
|
||||
{
|
||||
int seq_len_idx = h * batch_size + b;
|
||||
EXPECT_EQ(output_seq_len_ptr[seq_len_idx], expected_seq_lens[h][b])
|
||||
<< "Sequence length mismatch at head=" << h << ", batch=" << b;
|
||||
EXPECT_EQ(output_seq_len_ptr[seq_len_idx], expected_results[h][b].seq_len)
|
||||
<< "Sequence length mismatch at head=" << h << ", batch=" << b
|
||||
<< ", expected=" << expected_results[h][b].seq_len << ", got=" << output_seq_len_ptr[seq_len_idx];
|
||||
}
|
||||
}
|
||||
|
||||
@ -159,20 +207,13 @@ TEST_F(sparseAttentionKernelsTest, GatherKvPageOffsetsKernelTest)
|
||||
{
|
||||
for (int b = 0; b < batch_size; ++b)
|
||||
{
|
||||
int num_sparse_pages_batch = sparse_indices_offsets_ptr[b + 1] - sparse_indices_offsets_ptr[b];
|
||||
auto const& expected = expected_results[h][b];
|
||||
|
||||
for (int d = 0; d < 2; ++d)
|
||||
{
|
||||
for (int p = 0; p < num_sparse_pages_batch; ++p)
|
||||
for (int p = 0; p < expected.num_pages; ++p)
|
||||
{
|
||||
// Calculate expected value (from the sparse index)
|
||||
int sparse_idx_global = sparse_indices_offsets_ptr[b] + p;
|
||||
int src_page_idx
|
||||
= sparse_indices_ptr[h * sparse_indices_offsets_ptr[batch_size] + sparse_idx_global];
|
||||
|
||||
if (src_page_idx == -1)
|
||||
{
|
||||
continue; // Skip invalid indices
|
||||
}
|
||||
int src_page_idx = expected.page_indices[p];
|
||||
|
||||
// Calculate output offset
|
||||
size_t output_offset = h * batch_size * 2 * max_num_pages_per_seq + b * 2 * max_num_pages_per_seq
|
||||
@ -181,7 +222,9 @@ TEST_F(sparseAttentionKernelsTest, GatherKvPageOffsetsKernelTest)
|
||||
int expected_value = 1000 + b * 100 + d * 10 + src_page_idx;
|
||||
|
||||
EXPECT_EQ(output_kv_offsets_ptr[output_offset], expected_value)
|
||||
<< "Mismatch at head=" << h << ", batch=" << b << ", dim=" << d << ", page=" << p;
|
||||
<< "KV page offset mismatch at head=" << h << ", batch=" << b << ", dim=" << d << ", page=" << p
|
||||
<< ", expected_page_idx=" << src_page_idx << ", expected_value=" << expected_value
|
||||
<< ", got=" << output_kv_offsets_ptr[output_offset];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -6,10 +6,11 @@ This example demonstrates how to use sparse attention with TensorRT-LLM.
|
||||
|
||||
Supported sparse attention algorithms:
|
||||
- RocketKV
|
||||
- DSA
|
||||
|
||||
Usage:
|
||||
```bash
|
||||
python llm_sparse_attention.py --algo RocketKV --attention_backend TRTLLM --window_size 32 --kernel_size 63 --prompt_budget 2048
|
||||
python llm_sparse_attention.py --algo ROCKETKV --attention_backend TRTLLM --window_size 32 --kernel_size 63 --prompt_budget 2048
|
||||
```
|
||||
"""
|
||||
import argparse
|
||||
@ -70,7 +71,7 @@ def parse_arguments():
|
||||
help="The maximum chunk size for the indexer.")
|
||||
parser.add_argument("--max_seq_len",
|
||||
type=int,
|
||||
default=8192,
|
||||
default=10240,
|
||||
help="The maximum sequence length.")
|
||||
parser.add_argument("--max_batch_size",
|
||||
type=int,
|
||||
@ -83,7 +84,7 @@ def parse_arguments():
|
||||
parser.add_argument(
|
||||
"--max_num_tokens",
|
||||
type=int,
|
||||
default=8192,
|
||||
default=81920,
|
||||
help=
|
||||
"The maximum total tokens (context + generation) across all sequences in a batch."
|
||||
)
|
||||
@ -104,7 +105,7 @@ def parse_arguments():
|
||||
|
||||
# KV cache
|
||||
parser.add_argument('--kv_cache_dtype', type=str, default='auto')
|
||||
parser.add_argument("--kv_cache_fraction", type=float, default=None)
|
||||
parser.add_argument("--kv_cache_fraction", type=float, default=0.7)
|
||||
parser.add_argument('--num_samples', type=int, default=10)
|
||||
|
||||
# Runtime
|
||||
|
||||
@ -356,7 +356,6 @@ def initialize_llm(args: argparse.Namespace) -> Tuple[LLM, AutoTokenizer]:
|
||||
max_seq_len=args.max_seq_len,
|
||||
max_num_tokens=args.max_num_tokens,
|
||||
cuda_graph_config=cuda_graph_config,
|
||||
torch_compile_config=None,
|
||||
print_iter_log=args.print_iter_log,
|
||||
moe_config=MoeConfig(backend=args.moe_backend),
|
||||
)
|
||||
|
||||
@ -28,7 +28,8 @@ from transformers import AutoTokenizer
|
||||
|
||||
# Add tensorrt_llm imports
|
||||
from tensorrt_llm import LLM, SamplingParams
|
||||
from tensorrt_llm.llmapi import KvCacheConfig, RocketSparseAttentionConfig
|
||||
from tensorrt_llm.llmapi import (CudaGraphConfig, KvCacheConfig,
|
||||
RocketSparseAttentionConfig)
|
||||
from tensorrt_llm.logger import logger
|
||||
|
||||
# Chat templates mapping
|
||||
@ -362,6 +363,10 @@ def initialize_llm(args: argparse.Namespace) -> Tuple[LLM, AutoTokenizer]:
|
||||
sparse_attention_config = None
|
||||
logger.info("Using standard attention")
|
||||
|
||||
cuda_graph_config = CudaGraphConfig(
|
||||
max_batch_size=args.max_batch_size
|
||||
) if args.attention_backend == "TRTLLM" else None
|
||||
|
||||
# Initialize LLM
|
||||
llm = LLM(
|
||||
model=args.model_path,
|
||||
@ -372,8 +377,7 @@ def initialize_llm(args: argparse.Namespace) -> Tuple[LLM, AutoTokenizer]:
|
||||
tensor_parallel_size=args.tensor_parallel_size,
|
||||
max_seq_len=args.max_seq_len,
|
||||
max_num_tokens=args.max_num_tokens,
|
||||
cuda_graph_config=None,
|
||||
torch_compile_config=None,
|
||||
cuda_graph_config=cuda_graph_config,
|
||||
)
|
||||
|
||||
# Initialize tokenizer
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -23,7 +23,13 @@ from tensorrt_llm.bindings.internal.batch_manager import \
|
||||
from tensorrt_llm.mapping import Mapping
|
||||
from tensorrt_llm.models.modeling_utils import QuantConfig
|
||||
|
||||
from .kernel import triton_index_gather, triton_update_kt_cache
|
||||
from .kernel import (triton_bmm, triton_flatten_to_batch, triton_index_gather,
|
||||
triton_rocket_batch_to_flatten,
|
||||
triton_rocket_paged_kt_cache_bmm, triton_rocket_qk_split,
|
||||
triton_rocket_reduce_scores,
|
||||
triton_rocket_update_kt_cache_ctx,
|
||||
triton_rocket_update_kt_cache_gen, triton_softmax,
|
||||
triton_topk)
|
||||
|
||||
ModelConfig = tensorrt_llm.bindings.ModelConfig
|
||||
|
||||
@ -35,8 +41,77 @@ class RocketTrtllmAttentionMetadata(TrtllmAttentionMetadata):
|
||||
if self.sparse_attention_config is None:
|
||||
raise ValueError("Sparse attention config is not set")
|
||||
self.prompt_budget = self.sparse_attention_config.prompt_budget
|
||||
self.window_size = self.sparse_attention_config.window_size
|
||||
self.page_size = self.sparse_attention_config.page_size
|
||||
self.topk = self.sparse_attention_config.topk
|
||||
|
||||
capture_graph = torch.cuda.is_current_stream_capturing()
|
||||
|
||||
# Cumulative valid sequence lengths for query and key
|
||||
self.q_cu_seqlens_cuda = self.get_empty(
|
||||
self.cuda_graph_buffers,
|
||||
(self.max_num_sequences + 1, ),
|
||||
dtype=torch.int32,
|
||||
cache_name="q_cu_seqlens_cuda",
|
||||
capture_graph=capture_graph,
|
||||
)
|
||||
|
||||
self.q_cu_seqlens = torch.zeros_like(self.q_cu_seqlens_cuda,
|
||||
device='cpu',
|
||||
dtype=torch.int32)
|
||||
|
||||
self.k_cu_seqlens_cuda = self.get_empty(
|
||||
self.cuda_graph_buffers,
|
||||
(self.max_num_sequences + 1, ),
|
||||
dtype=torch.int32,
|
||||
cache_name="k_cu_seqlens_cuda",
|
||||
capture_graph=capture_graph,
|
||||
)
|
||||
self.k_cu_seqlens = torch.zeros_like(self.k_cu_seqlens_cuda,
|
||||
device='cpu',
|
||||
dtype=torch.int32)
|
||||
|
||||
# Context length of RocketKV key for each valid sequence
|
||||
self.k_context_lens = torch.empty(
|
||||
self.max_num_sequences,
|
||||
device='cpu',
|
||||
dtype=torch.int32,
|
||||
)
|
||||
|
||||
# Cumulative context lengths for each sequence
|
||||
self.context_cumsum_cuda = self.get_empty(
|
||||
self.cuda_graph_buffers,
|
||||
(self.max_num_sequences + 1, ),
|
||||
dtype=torch.int32,
|
||||
cache_name="context_cumsum_cuda",
|
||||
capture_graph=capture_graph,
|
||||
)
|
||||
self.context_cumsum = torch.zeros_like(self.context_cumsum_cuda,
|
||||
device='cpu',
|
||||
dtype=torch.int32)
|
||||
|
||||
# Sparse kv indices offsets for each sequence in context phase
|
||||
self.sparse_offsets_ctx_cuda = self.get_empty(
|
||||
self.cuda_graph_buffers,
|
||||
(self.max_num_sequences + 1, ),
|
||||
dtype=torch.int32,
|
||||
cache_name="sparse_offsets_ctx_cuda",
|
||||
capture_graph=capture_graph,
|
||||
)
|
||||
self.sparse_offsets_ctx = torch.zeros_like(self.sparse_offsets_ctx_cuda,
|
||||
device='cpu',
|
||||
dtype=torch.int32)
|
||||
|
||||
# Valid sequence indices used in sparse kv indices prediction
|
||||
self.valid_seq_indices_cuda = self.get_empty(
|
||||
self.cuda_graph_buffers,
|
||||
(self.max_num_sequences, ),
|
||||
dtype=torch.int32,
|
||||
cache_name="valid_seq_indices_cuda",
|
||||
capture_graph=capture_graph,
|
||||
)
|
||||
|
||||
# KT cache block offsets used in KT cache related kernels
|
||||
self.kt_cache_block_offsets = self.get_empty(
|
||||
self.cuda_graph_buffers,
|
||||
[
|
||||
@ -54,6 +129,41 @@ class RocketTrtllmAttentionMetadata(TrtllmAttentionMetadata):
|
||||
pin_memory=True,
|
||||
)
|
||||
|
||||
# Number of KT tokens for each sequence
|
||||
self.num_kt_tokens = torch.empty(
|
||||
self.max_num_sequences,
|
||||
device='cpu',
|
||||
dtype=torch.int32,
|
||||
)
|
||||
|
||||
# Cumulative KT lengths for each sequence
|
||||
self.cum_kt_lens_cuda = self.get_empty(
|
||||
self.cuda_graph_buffers,
|
||||
(self.max_num_sequences + 1, ),
|
||||
dtype=torch.int32,
|
||||
cache_name="cum_kt_lens_cuda",
|
||||
capture_graph=capture_graph,
|
||||
)
|
||||
self.cum_kt_lens = torch.zeros_like(self.cum_kt_lens_cuda,
|
||||
device='cpu',
|
||||
dtype=torch.int32)
|
||||
|
||||
# Sparse attn indices offsets for each sequence in generation phase
|
||||
self.sparse_offsets_gen_cuda = self.get_empty(
|
||||
self.cuda_graph_buffers,
|
||||
(self.max_num_sequences + 1, ),
|
||||
dtype=torch.int32,
|
||||
cache_name="sparse_offsets_gen_cuda",
|
||||
capture_graph=capture_graph,
|
||||
)
|
||||
self.sparse_offsets_gen = torch.zeros_like(self.sparse_offsets_gen_cuda,
|
||||
device='cpu',
|
||||
dtype=torch.int32)
|
||||
|
||||
# Maximum number of KT tokens
|
||||
self.max_kt_tokens = (self.max_seq_len + self.page_size -
|
||||
1) // self.page_size
|
||||
|
||||
@property
|
||||
def kt_tokens_per_block(self) -> Optional[int]:
|
||||
"""
|
||||
@ -100,81 +210,85 @@ class RocketTrtllmAttentionMetadata(TrtllmAttentionMetadata):
|
||||
self.host_kt_cache_block_offsets[:self.num_seqs],
|
||||
non_blocking=True)
|
||||
|
||||
# -------------------------------- Context phase --------------------------------
|
||||
self.context_cumsum[1:self.num_contexts + 1] = torch.cumsum(
|
||||
self.prompt_lens_cpu[:self.num_contexts], dim=0)
|
||||
self.context_cumsum_cuda[:self.num_contexts + 1].copy_(
|
||||
self.context_cumsum[:self.num_contexts + 1], non_blocking=True)
|
||||
|
||||
@torch.compile(dynamic=True)
|
||||
def convert_token_to_page_sparse_indices(
|
||||
sparse_attn_indices: torch.Tensor, sparse_attn_offsets: torch.Tensor,
|
||||
metadata: 'TrtllmAttentionMetadata'
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Convert token-based sparse attention indices to page-based sparse attention indices.
|
||||
# We need to filter out sequences that are too short to skip sparse kv indices prediction
|
||||
valid_mask = self.prompt_lens_cpu[:self.
|
||||
num_contexts] >= self.prompt_budget
|
||||
valid_seq_indices = torch.where(valid_mask)[0]
|
||||
invalid_seq_indices = torch.where(~valid_mask)[0]
|
||||
valid_batch_size = len(valid_seq_indices)
|
||||
self.valid_seq_indices_cuda[:valid_batch_size].copy_(valid_seq_indices,
|
||||
non_blocking=True)
|
||||
|
||||
Args:
|
||||
sparse_attn_indices: Token-based indices with shape [num_tokens, num_kv_heads]
|
||||
sparse_attn_offsets: Offsets with shape [batch_size+1] indicating token boundaries for each batch
|
||||
metadata: Attention metadata containing tokens_per_block (page_size)
|
||||
# Only consider sequences that are long enough for sparse kv indices prediction in context phase
|
||||
self.k_context_lens[:valid_batch_size] = self.prompt_lens_cpu[
|
||||
valid_seq_indices] - self.window_size
|
||||
if valid_batch_size > 0:
|
||||
# Maximum context length of RocketKV key for valid sequences for padding
|
||||
self.max_rocket_k_ctx_len = self.k_context_lens[:
|
||||
valid_batch_size].max(
|
||||
).item()
|
||||
else:
|
||||
self.max_rocket_k_ctx_len = 0
|
||||
|
||||
Returns:
|
||||
Tuple of (page_indices, page_offsets):
|
||||
- page_indices: Page-based indices with shape [num_pages, num_kv_heads]
|
||||
- page_offsets: Updated offsets with shape [batch_size+1] indicating page boundaries for each batch
|
||||
sparse_counts_ctx = torch.zeros(self.num_contexts,
|
||||
dtype=torch.int32,
|
||||
device='cpu')
|
||||
sparse_counts_ctx[valid_seq_indices] = self.prompt_budget
|
||||
sparse_counts_ctx[invalid_seq_indices] = self.prompt_lens_cpu[
|
||||
invalid_seq_indices]
|
||||
|
||||
Example:
|
||||
If sparse_attn_indices first dimension is [1, 30, 67] and page_size=32,
|
||||
the result will be [0, 2] (token 1 -> page 0, token 30 -> page 0, token 67 -> page 2)
|
||||
"""
|
||||
page_size = metadata.tokens_per_block
|
||||
batch_size = sparse_attn_offsets.size(0) - 1
|
||||
num_kv_heads = sparse_attn_indices.size(1)
|
||||
self.sparse_offsets_ctx[1:self.num_contexts + 1] = torch.cumsum(
|
||||
sparse_counts_ctx, dim=0)
|
||||
self.sparse_offsets_ctx_cuda[:self.num_contexts + 1].copy_(
|
||||
self.sparse_offsets_ctx[:self.num_contexts + 1], non_blocking=True)
|
||||
|
||||
# Convert token indices to page indices
|
||||
page_indices = sparse_attn_indices // page_size
|
||||
self.q_cu_seqlens[:valid_batch_size + 1] = torch.arange(
|
||||
valid_batch_size + 1, device='cpu',
|
||||
dtype=torch.int32) * self.window_size
|
||||
self.q_cu_seqlens_cuda[:valid_batch_size + 1].copy_(
|
||||
self.q_cu_seqlens[:valid_batch_size + 1], non_blocking=True)
|
||||
|
||||
# Process each batch and each kv_head separately to remove duplicates
|
||||
new_page_indices_list = []
|
||||
new_offsets = torch.zeros_like(sparse_attn_offsets)
|
||||
self.k_cu_seqlens[1:valid_batch_size + 1] = torch.cumsum(
|
||||
self.k_context_lens[:valid_batch_size], dim=0)
|
||||
self.k_cu_seqlens_cuda[:valid_batch_size + 1].copy_(
|
||||
self.k_cu_seqlens[:valid_batch_size + 1], non_blocking=True)
|
||||
|
||||
current_offset = 0
|
||||
for batch_idx in range(batch_size):
|
||||
start_idx = sparse_attn_offsets[batch_idx]
|
||||
end_idx = sparse_attn_offsets[batch_idx + 1]
|
||||
self.valid_batch_size = valid_batch_size
|
||||
self.total_sparse_ctx_indices = self.sparse_offsets_ctx[
|
||||
self.num_contexts].item()
|
||||
|
||||
if start_idx >= end_idx:
|
||||
# Empty batch
|
||||
new_offsets[batch_idx + 1] = current_offset
|
||||
continue
|
||||
# -------------------------------- Generation phase --------------------------------
|
||||
self.num_kt_tokens[:self.num_generations] = (
|
||||
self.kv_lens[self.num_contexts:self.num_seqs] + self.page_size -
|
||||
1) // self.page_size
|
||||
|
||||
batch_page_indices = page_indices[
|
||||
start_idx:end_idx] # [num_tokens_in_batch, num_kv_heads]
|
||||
self.cum_kt_lens[1:self.num_generations + 1] = torch.cumsum(
|
||||
self.num_kt_tokens[:self.num_generations], dim=0)
|
||||
self.cum_kt_lens_cuda[:self.num_generations + 1].copy_(
|
||||
self.cum_kt_lens[:self.num_generations + 1], non_blocking=True)
|
||||
|
||||
# For each kv_head, remove duplicates while preserving order
|
||||
batch_unique_pages = []
|
||||
for head_idx in range(num_kv_heads):
|
||||
head_pages = batch_page_indices[:, head_idx]
|
||||
unique_pages = torch.unique(head_pages, sorted=False)
|
||||
batch_unique_pages.append(unique_pages)
|
||||
self.total_kt_tokens = self.num_generations * self.max_kt_tokens
|
||||
|
||||
# Find the maximum length among all heads for this batch
|
||||
max_len = max(pages.size(0) for pages in batch_unique_pages)
|
||||
topk_tensor = torch.tensor(self.topk, dtype=torch.int32)
|
||||
|
||||
if max_len > 0:
|
||||
batch_result = torch.full((max_len, num_kv_heads),
|
||||
fill_value=-1,
|
||||
dtype=page_indices.dtype,
|
||||
device=page_indices.device)
|
||||
# Some sequences may have less than topk KT tokens
|
||||
# We need to use the minimum of topk and the number of KT tokens
|
||||
sparse_counts_gen = torch.minimum(
|
||||
topk_tensor, self.num_kt_tokens[:self.num_generations])
|
||||
|
||||
for head_idx in range(num_kv_heads):
|
||||
unique_pages = batch_unique_pages[head_idx]
|
||||
batch_result[:unique_pages.size(0), head_idx] = unique_pages
|
||||
self.sparse_offsets_gen[1:self.num_generations + 1] = torch.cumsum(
|
||||
sparse_counts_gen[:self.num_generations], dim=0)
|
||||
self.sparse_offsets_gen_cuda[:self.num_generations + 1].copy_(
|
||||
self.sparse_offsets_gen[:self.num_generations + 1],
|
||||
non_blocking=True)
|
||||
|
||||
new_page_indices_list.append(batch_result)
|
||||
current_offset += max_len
|
||||
|
||||
new_offsets[batch_idx + 1] = current_offset
|
||||
|
||||
new_page_indices = torch.cat(new_page_indices_list, dim=0)
|
||||
|
||||
return new_page_indices, new_offsets
|
||||
self.total_sparse_gen_indices = self.topk * self.num_generations
|
||||
|
||||
|
||||
class RocketTrtllmAttention(TrtllmAttention):
|
||||
@ -213,85 +327,6 @@ class RocketTrtllmAttention(TrtllmAttention):
|
||||
self.kernel_size = sparse_attention_config.kernel_size
|
||||
self.page_size = sparse_attention_config.page_size
|
||||
|
||||
def sparse_attn_predict(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
k: Optional[torch.Tensor],
|
||||
metadata: TrtllmAttentionMetadata,
|
||||
**kwargs,
|
||||
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
"""
|
||||
Predict sparse attention indices.
|
||||
For RocketKV:
|
||||
- Generation phase: predict RocketKV sparse attention indices
|
||||
|
||||
Returns:
|
||||
- sparse_attn_indices: [total_selected_indices, num_kv_heads]
|
||||
- sparse_attn_offsets: [batch_size + 1] with cumulative indices count
|
||||
"""
|
||||
if k is None:
|
||||
q, k, _ = q.split([
|
||||
self.num_heads * self.head_dim, self.num_kv_heads *
|
||||
self.head_dim, self.num_kv_heads * self.head_dim
|
||||
],
|
||||
dim=-1)
|
||||
|
||||
num_contexts = metadata.num_contexts
|
||||
num_generations = metadata.num_generations
|
||||
seq_lens = metadata.seq_lens
|
||||
seq_lens_kv = metadata.seq_lens_kv if metadata.seq_lens_kv is not None else seq_lens
|
||||
past_seen_tokens = metadata.kv_cache_params.num_cached_tokens_per_seq
|
||||
|
||||
sparse_attn_indices = []
|
||||
sparse_attn_offsets = [0]
|
||||
|
||||
q_offset = 0
|
||||
k_offset = 0
|
||||
|
||||
for i in range(num_contexts + num_generations):
|
||||
seq_len = seq_lens[i].item()
|
||||
seq_len_kv = seq_lens_kv[i].item()
|
||||
|
||||
if seq_len <= 0 or seq_len_kv <= 0:
|
||||
assert False, "Invalid sequence length"
|
||||
|
||||
single_q = q[q_offset:q_offset + seq_len]
|
||||
single_k = k[k_offset:k_offset + seq_len_kv]
|
||||
|
||||
single_q = single_q.view(1, seq_len, self.num_heads,
|
||||
self.head_dim).transpose(1, 2)
|
||||
single_k = single_k.view(1, seq_len_kv, self.num_kv_heads,
|
||||
self.head_dim)
|
||||
|
||||
past_seen_token = past_seen_tokens[i]
|
||||
# Generation phase: RocketKV sparse attention indices
|
||||
if i >= num_contexts:
|
||||
_sparse_attn_indices = self._rocketkv_selection(
|
||||
single_q, single_k, past_seen_token, metadata, i)
|
||||
if _sparse_attn_indices is not None:
|
||||
sparse_attn_indices.append(
|
||||
_sparse_attn_indices.squeeze(0)) # [topk, num_kv_heads]
|
||||
sparse_attn_offsets.append(sparse_attn_offsets[-1] +
|
||||
_sparse_attn_indices.size(1))
|
||||
else:
|
||||
sparse_attn_offsets.append(sparse_attn_offsets[-1])
|
||||
|
||||
q_offset += seq_len
|
||||
k_offset += seq_len_kv
|
||||
|
||||
if len(sparse_attn_indices) == 0:
|
||||
sparse_attn_indices, sparse_attn_offsets = None, None
|
||||
else:
|
||||
sparse_attn_indices = torch.cat(sparse_attn_indices,
|
||||
dim=0).to(torch.int32)
|
||||
sparse_attn_offsets = torch.tensor(sparse_attn_offsets,
|
||||
dtype=torch.int32).to(q.device)
|
||||
sparse_attn_indices, sparse_attn_offsets = convert_token_to_page_sparse_indices(
|
||||
sparse_attn_indices, sparse_attn_offsets, metadata)
|
||||
sparse_attn_indices = sparse_attn_indices.transpose(0,
|
||||
1).contiguous()
|
||||
return sparse_attn_indices, sparse_attn_offsets
|
||||
|
||||
def sparse_kv_predict(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
@ -300,229 +335,203 @@ class RocketTrtllmAttention(TrtllmAttention):
|
||||
**kwargs,
|
||||
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
"""
|
||||
Predict sparse kv indices.
|
||||
Predict sparse KV indices using optimized SnapKV algorithm.
|
||||
|
||||
For RocketKV:
|
||||
- Context phase: predict RocketKV sparse kv indices
|
||||
|
||||
Returns:
|
||||
- flattened_indices: [total_selected_indices, num_kv_heads]
|
||||
- batch_offsets: [batch_size + 1] with cumulative indices count
|
||||
Uses a single Triton kernel to compute attention scores between observation window
|
||||
and prefix tokens, then selects the most important prefix tokens directly.
|
||||
"""
|
||||
|
||||
num_ctx_tokens = metadata.num_ctx_tokens
|
||||
if num_ctx_tokens == 0:
|
||||
return None, None
|
||||
|
||||
if k is None:
|
||||
q, k, _ = q.split([
|
||||
self.num_heads * self.head_dim, self.num_kv_heads *
|
||||
self.head_dim, self.num_kv_heads * self.head_dim
|
||||
],
|
||||
dim=-1)
|
||||
|
||||
num_contexts = metadata.num_contexts
|
||||
num_generations = metadata.num_generations
|
||||
seq_lens = metadata.seq_lens
|
||||
seq_lens_kv = metadata.seq_lens_kv if metadata.seq_lens_kv is not None else seq_lens
|
||||
past_seen_tokens = metadata.kv_cache_params.num_cached_tokens_per_seq
|
||||
|
||||
sparse_kv_indices = []
|
||||
sparse_kv_offsets = [0]
|
||||
|
||||
q_offset = 0
|
||||
k_offset = 0
|
||||
|
||||
for i in range(num_contexts + num_generations):
|
||||
seq_len = seq_lens[i].item()
|
||||
seq_len_kv = seq_lens_kv[i].item()
|
||||
|
||||
if seq_len <= 0 or seq_len_kv <= 0:
|
||||
assert False, "Invalid sequence length"
|
||||
|
||||
single_q = q[q_offset:q_offset + seq_len]
|
||||
single_k = k[k_offset:k_offset + seq_len_kv]
|
||||
|
||||
single_q = single_q.view(1, seq_len, self.num_heads,
|
||||
self.head_dim).transpose(1, 2)
|
||||
single_k = single_k.view(1, seq_len_kv, self.num_kv_heads,
|
||||
self.head_dim)
|
||||
|
||||
past_seen_token = past_seen_tokens[i]
|
||||
if i < num_contexts:
|
||||
# Context phase: SnapKV sparse kv indices
|
||||
_sparse_kv_indices = self._get_snapkv_indices(
|
||||
single_q, single_k, past_seen_token, metadata, i)
|
||||
if _sparse_kv_indices is not None:
|
||||
sparse_kv_indices.append(
|
||||
_sparse_kv_indices.squeeze(0)) # [budget, num_kv_heads]
|
||||
sparse_kv_offsets.append(sparse_kv_offsets[-1] +
|
||||
_sparse_kv_indices.size(1))
|
||||
else:
|
||||
sparse_kv_offsets.append(sparse_kv_offsets[-1])
|
||||
|
||||
q_offset += seq_len
|
||||
k_offset += seq_len_kv
|
||||
|
||||
if len(sparse_kv_indices) == 0:
|
||||
sparse_kv_indices, sparse_kv_offsets = None, None
|
||||
qkv_input = q[:num_ctx_tokens]
|
||||
else:
|
||||
sparse_kv_indices = torch.cat(sparse_kv_indices,
|
||||
dim=0).to(torch.int32)
|
||||
sparse_kv_indices = sparse_kv_indices.transpose(0, 1).contiguous()
|
||||
sparse_kv_offsets = torch.tensor(sparse_kv_offsets,
|
||||
dtype=torch.int32).to(q.device)
|
||||
return sparse_kv_indices, sparse_kv_offsets
|
||||
qkv_input = torch.cat([q, k], dim=1)
|
||||
|
||||
def _get_snapkv_indices(self, q: Tensor, k: Tensor, past_seen_token: int,
|
||||
metadata: RocketTrtllmAttentionMetadata,
|
||||
sample_idx: int) -> Optional[Tensor]:
|
||||
"""
|
||||
Get SnapKV sparse kv indices from the input sequence for context phase.
|
||||
The shape of output is (1, prompt_budget, num_kv_heads)
|
||||
"""
|
||||
bsz = 1
|
||||
seq_len = k.size(1) # k shape: (1, seq_len, num_kv_heads, head_dim)
|
||||
if metadata.valid_batch_size > 0:
|
||||
q_window, k_context = triton_rocket_qk_split(
|
||||
qkv_input,
|
||||
metadata.prompt_lens_cuda,
|
||||
metadata.context_cumsum_cuda,
|
||||
metadata.valid_seq_indices_cuda,
|
||||
metadata.k_cu_seqlens_cuda,
|
||||
self.num_heads,
|
||||
self.num_kv_heads,
|
||||
self.head_dim,
|
||||
self.window_size,
|
||||
metadata.valid_batch_size,
|
||||
)
|
||||
|
||||
# If the sequence length is less than the prompt budget, do not enable sparse kv cache
|
||||
if seq_len <= self.prompt_budget:
|
||||
return None
|
||||
scores = triton_bmm(q_window,
|
||||
k_context,
|
||||
metadata.q_cu_seqlens_cuda,
|
||||
metadata.k_cu_seqlens_cuda,
|
||||
metadata.valid_batch_size,
|
||||
causal=False)
|
||||
|
||||
# Use last window_size tokens as observation
|
||||
# (1, num_heads, window_size, head_dim)
|
||||
q_obs = q[:, :, -self.window_size:]
|
||||
# (1, num_kv_heads, seq_len, head_dim)
|
||||
k_pre = repeat_kv(k.transpose(1, 2),
|
||||
self.num_heads // self.num_kv_heads)
|
||||
scores = triton_softmax(scores, metadata.k_cu_seqlens_cuda,
|
||||
metadata.valid_batch_size)
|
||||
|
||||
dist = (torch.arange(0, self.window_size, device=q.device)[:, None] -
|
||||
torch.arange(0, seq_len, device=q.device)[None, :] + seq_len -
|
||||
self.window_size)
|
||||
attention_mask = (dist >= 0)
|
||||
# scores: [num_heads, window_size, total_k_tokens] -> [num_kv_heads, total_k_tokens]
|
||||
scores = scores.view(self.num_kv_heads,
|
||||
self.num_heads // self.num_kv_heads,
|
||||
self.window_size, -1).sum(dim=(1, 2))
|
||||
|
||||
score = torch.matmul(q_obs, k_pre.transpose(-1, -2)) / math.sqrt(
|
||||
self.head_dim)
|
||||
# Reshape scores to handle variable length sequences with padding using Triton
|
||||
# scores: [num_kv_heads, total_k_tokens] -> [valid_batch_size, num_kv_heads, padding_size]
|
||||
scores = triton_flatten_to_batch(scores, metadata.k_cu_seqlens_cuda,
|
||||
metadata.valid_batch_size,
|
||||
metadata.max_rocket_k_ctx_len)
|
||||
|
||||
score = torch.masked_fill(
|
||||
score,
|
||||
attention_mask.view(1, 1, self.window_size, seq_len) == False,
|
||||
torch.scalar_tensor(float("-inf"),
|
||||
device=score.device,
|
||||
dtype=score.dtype))
|
||||
scores = torch.nn.functional.max_pool1d(
|
||||
scores,
|
||||
kernel_size=self.kernel_size,
|
||||
padding=self.kernel_size // 2,
|
||||
stride=1)
|
||||
|
||||
score = torch.nn.functional.softmax(score, dim=-1)
|
||||
selected_prefix_indices = scores.topk(
|
||||
self.prompt_budget - self.window_size,
|
||||
dim=-1).indices.sort().values.to(torch.int32)
|
||||
else:
|
||||
selected_prefix_indices = torch.empty(
|
||||
(0, self.num_kv_heads, self.prompt_budget - self.window_size),
|
||||
device=qkv_input.device,
|
||||
dtype=torch.int32)
|
||||
|
||||
score = torch.masked_fill(
|
||||
score,
|
||||
attention_mask.view(1, 1, self.window_size, seq_len) == False,
|
||||
torch.scalar_tensor(0, device=score.device, dtype=score.dtype))
|
||||
sparse_kv_offsets = metadata.sparse_offsets_ctx_cuda[:metadata.
|
||||
num_contexts + 1]
|
||||
|
||||
score = score[:, :, -self.window_size:, :-self.window_size].sum(dim=-2)
|
||||
# Flatten sparse indices
|
||||
sparse_kv_indices = triton_rocket_batch_to_flatten(
|
||||
selected_prefix_indices, metadata.prompt_lens_cuda,
|
||||
metadata.valid_seq_indices_cuda, sparse_kv_offsets,
|
||||
metadata.num_contexts, metadata.total_sparse_ctx_indices,
|
||||
self.window_size, self.prompt_budget)
|
||||
|
||||
score = score.view(bsz, self.num_kv_heads,
|
||||
self.num_heads // self.num_kv_heads, -1).sum(dim=2)
|
||||
score = torch.nn.functional.max_pool1d(score,
|
||||
kernel_size=self.kernel_size,
|
||||
padding=self.kernel_size // 2,
|
||||
stride=1)
|
||||
|
||||
# Select top important tokens from prefix
|
||||
prefix_len = seq_len - self.window_size
|
||||
selected_prefix_indices = score.topk(self.prompt_budget -
|
||||
self.window_size,
|
||||
dim=-1).indices.sort().values
|
||||
|
||||
# Combine selected prefix indices with window indices
|
||||
window_indices = torch.arange(
|
||||
prefix_len, seq_len,
|
||||
device=k.device).unsqueeze(0).unsqueeze(0).expand(
|
||||
bsz, self.num_kv_heads, -1)
|
||||
selected_indices = torch.cat([selected_prefix_indices, window_indices],
|
||||
dim=-1).transpose(1, 2)
|
||||
|
||||
k = k.reshape(1, -1, self.num_kv_heads, self.head_dim)
|
||||
k_snap = triton_index_gather(k, selected_indices)
|
||||
# Update KT cache
|
||||
kt_cache_tensor = metadata.kv_cache_manager.get_kt_buffers(
|
||||
self.layer_idx)
|
||||
k_snap_len = torch.clamp(
|
||||
metadata.kv_lens_cuda_runtime[sample_idx:sample_idx + 1],
|
||||
max=self.prompt_budget).int()
|
||||
triton_update_kt_cache(
|
||||
k_snap.squeeze(0).contiguous(),
|
||||
|
||||
triton_rocket_update_kt_cache_ctx(
|
||||
qkv_input.contiguous(),
|
||||
kt_cache_tensor,
|
||||
metadata.kt_cache_block_offsets[sample_idx:sample_idx + 1],
|
||||
k_snap_len,
|
||||
metadata.kt_cache_block_offsets[:metadata.num_contexts],
|
||||
metadata.context_cumsum_cuda[:metadata.num_contexts + 1],
|
||||
sparse_kv_indices,
|
||||
sparse_kv_offsets,
|
||||
self.num_heads,
|
||||
self.num_kv_heads,
|
||||
self.head_dim,
|
||||
self.page_size,
|
||||
metadata.kt_tokens_per_block,
|
||||
metadata.kv_cache_manager.max_kt_blocks_per_seq,
|
||||
update=False)
|
||||
)
|
||||
|
||||
return selected_indices
|
||||
# Reduce overhead of post processing
|
||||
if metadata.valid_batch_size == 0:
|
||||
return None, None
|
||||
|
||||
def _rocketkv_selection(self, q: Tensor, k: Tensor, past_seen_token: int,
|
||||
metadata: RocketTrtllmAttentionMetadata,
|
||||
sample_idx: int) -> Tensor:
|
||||
"""
|
||||
Implement RocketKV's two-stage selection process for generation phase.
|
||||
The shape of output is (1, topk, num_kv_heads)
|
||||
"""
|
||||
bsz = 1
|
||||
q_len = q.size(2)
|
||||
return sparse_kv_indices, sparse_kv_offsets
|
||||
|
||||
# Helper functions
|
||||
def _gather(t: Tensor, dim: int, i: Tensor) -> Tensor:
|
||||
dim += (dim < 0) * t.ndim
|
||||
return t.gather(
|
||||
dim, i.expand(*t.shape[:dim], i.shape[dim], *t.shape[dim + 1:]))
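# Illustrative usage sketch (shapes assumed for this note, not part of the diff):
# _gather broadcasts an index tensor that has size 1 on the leading dims and
# gathers along `dim`. With q of shape (B, H, G, D) and i1 of shape (B, H, 1, r)
# taken from torch.topk(q.abs().mean(dim=2, keepdim=True), r, dim=-1).indices,
# _gather(q, -1, i1) returns the selected r feature columns with shape (B, H, G, r).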
|
||||
@torch.compile(dynamic=True)
|
||||
def preprocess_for_gen(
|
||||
self, q: torch.Tensor, k: Optional[torch.Tensor],
|
||||
metadata: RocketTrtllmAttentionMetadata
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
if k is None:
|
||||
qkv_input = q[metadata.num_ctx_tokens:]
|
||||
q_hidden_size = self.num_heads * self.head_dim
|
||||
k_hidden_size = self.num_kv_heads * self.head_dim
|
||||
q = qkv_input[:, :q_hidden_size]
|
||||
k = qkv_input[:, q_hidden_size:q_hidden_size + k_hidden_size]
|
||||
else:
|
||||
q = q[metadata.num_ctx_tokens:]
|
||||
k = k[metadata.num_ctx_tokens:]
|
||||
|
||||
@torch.compile(disable=not torch.cuda.is_available())
|
||||
def _scaled_softmax(x: Tensor, divscale: Tensor | float,
|
||||
dim: int) -> Tensor:
|
||||
return torch.softmax(x / divscale, dim=dim)
|
||||
q = q.view(-1, self.num_kv_heads, self.num_heads // self.num_kv_heads,
|
||||
self.head_dim)
|
||||
|
||||
q_abs = torch.abs(q)
|
||||
q_mask = torch.zeros_like(q)
|
||||
|
||||
i1 = torch.topk(q_abs.mean(dim=2, keepdim=True), self.topr,
|
||||
dim=-1).indices
|
||||
|
||||
q_mask.scatter_(-1, i1.expand_as(q[..., :self.topr]), 1)
|
||||
|
||||
q_valid = q * q_mask
|
||||
|
||||
dim_pos = torch.where(q_valid.sum(dim=2) > 0, self.head_dim,
|
||||
0).to(torch.int32)
|
||||
|
||||
return q_valid, k, dim_pos
|
||||
|
||||
def sparse_attn_predict(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
k: Optional[torch.Tensor],
|
||||
metadata: TrtllmAttentionMetadata,
|
||||
**kwargs,
|
||||
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
if metadata.num_generations == 0:
|
||||
return None, None
|
||||
|
||||
q, k, dim_pos = self.preprocess_for_gen(q, k, metadata)
|
||||
|
||||
# Get KT cache for key-token matching
|
||||
kt_cache_tensor = metadata.kv_cache_manager.get_kt_buffers(
|
||||
self.layer_idx)
|
||||
target_seq_len = past_seen_token + 1 # +1 for current token
|
||||
|
||||
# Update KT cache
|
||||
kt_states = triton_update_kt_cache(
|
||||
k.squeeze(0).contiguous(), kt_cache_tensor,
|
||||
metadata.kt_cache_block_offsets[sample_idx:sample_idx + 1],
|
||||
metadata.kv_lens_cuda_runtime[sample_idx:sample_idx + 1],
|
||||
self.page_size, metadata.kt_tokens_per_block,
|
||||
metadata.kv_cache_manager.max_kt_blocks_per_seq)
|
||||
kt_states = kt_states.unsqueeze(0).permute(0, 2, 3, 1)
|
||||
# Update KT cache with new key values
|
||||
triton_rocket_update_kt_cache_gen(
|
||||
k,
|
||||
kt_cache_tensor,
|
||||
metadata.kt_cache_block_offsets[metadata.num_contexts:],
|
||||
metadata.kv_lens_cuda_runtime[metadata.num_contexts:],
|
||||
metadata.page_size,
|
||||
metadata.kt_tokens_per_block,
|
||||
metadata.kv_cache_manager.max_kt_blocks_per_seq,
|
||||
self.num_kv_heads,
|
||||
self.head_dim,
|
||||
)
|
||||
|
||||
# Reshape query for multi-head processing
|
||||
qi = q.view(bsz, self.num_kv_heads, self.num_heads // self.num_kv_heads,
|
||||
q_len, self.head_dim)
|
||||
qi_abs = torch.abs(qi)
|
||||
# Perform BMM with updated cache
|
||||
scores = triton_rocket_paged_kt_cache_bmm(
|
||||
q,
|
||||
kt_cache_tensor,
|
||||
metadata.kt_cache_block_offsets[metadata.num_contexts:],
|
||||
dim_pos,
|
||||
metadata.kv_lens_cuda_runtime[metadata.num_contexts:],
|
||||
metadata.cum_kt_lens_cuda,
|
||||
metadata.page_size,
|
||||
metadata.kt_tokens_per_block,
|
||||
metadata.kv_cache_manager.max_kt_blocks_per_seq,
|
||||
metadata.total_kt_tokens,
|
||||
)
|
||||
|
||||
# Top-r selection on query features
|
||||
i1 = torch.topk(qi_abs.mean(dim=2, keepdim=True), self.topr,
|
||||
dim=-1).indices
|
||||
qi_hat = _gather(qi, -1, i1)
|
||||
scores = triton_softmax(scores, metadata.cum_kt_lens_cuda,
|
||||
metadata.num_generations)
|
||||
|
||||
# Generate signed indices for key-token matching
|
||||
i1_sign = torch.where(
|
||||
qi_hat.sum(dim=2, keepdim=True) > 0, i1 + self.head_dim,
|
||||
i1).transpose(-1, -2)
|
||||
# Mean over num_heads_per_kv for each batch separately
|
||||
scores = triton_rocket_reduce_scores(
|
||||
scores,
|
||||
metadata.cum_kt_lens_cuda,
|
||||
metadata.num_generations,
|
||||
self.num_kv_heads,
|
||||
self.num_heads // self.num_kv_heads,
|
||||
)
|
||||
|
||||
# Gather key tokens and compute attention scores
|
||||
kt_hat = _gather(kt_states.unsqueeze(2), -2, i1_sign)
|
||||
qk_hat = qi_hat @ kt_hat
|
||||
qk_hat = qk_hat.repeat_interleave(self.page_size,
|
||||
dim=-1)[:, :, :, :, :target_seq_len]
|
||||
scale = torch.sqrt(self.head_dim *
|
||||
torch.abs(qi_hat).sum(dim=-1, keepdim=True) /
|
||||
qi_abs.sum(dim=-1, keepdim=True))
|
||||
sparse_attn_offsets = metadata.sparse_offsets_gen_cuda[:metadata.
|
||||
num_generations +
|
||||
1]
|
||||
|
||||
# (1, num_kv_heads, num_heads, target_seq_len)
|
||||
s_hat = _scaled_softmax(qk_hat, scale, dim=-1)
|
||||
selected_indices = triton_topk(scores, metadata.cum_kt_lens_cuda,
|
||||
sparse_attn_offsets,
|
||||
metadata.total_sparse_gen_indices,
|
||||
metadata.topk)
|
||||
|
||||
topk = min(self.topk, target_seq_len)
|
||||
i2 = torch.topk(s_hat.mean(dim=2, keepdim=True), topk, dim=-1).indices
|
||||
|
||||
iKV = i2[:, :, 0, 0, :].transpose(1, 2) # (1, topk, num_kv_heads)
|
||||
|
||||
return iKV
|
||||
return selected_indices, sparse_attn_offsets
|
||||
|
||||
|
||||
class RocketVanillaAttentionMetadata(VanillaAttentionMetadata):
|
||||
@ -920,6 +929,7 @@ class RocketKVCacheManager(KVCacheManager):
|
||||
max_num_draft_tokens: int = 0,
|
||||
use_mrope: bool = False,
|
||||
max_beam_width: int = 1,
|
||||
num_extra_decoding_steps: int = 0,
|
||||
):
|
||||
requests = super().add_dummy_requests(
|
||||
request_ids=request_ids,
|
||||
@ -929,6 +939,7 @@ class RocketKVCacheManager(KVCacheManager):
|
||||
max_num_draft_tokens=max_num_draft_tokens,
|
||||
use_mrope=use_mrope,
|
||||
max_beam_width=max_beam_width,
|
||||
num_extra_decoding_steps=num_extra_decoding_steps,
|
||||
)
|
||||
if prepare_resource:
|
||||
for req in requests:
|
||||
|
||||
@ -197,6 +197,7 @@ class TrtllmAttentionWrapper:
|
||||
sparse_kv_offsets: Optional[torch.Tensor] = None,
|
||||
sparse_attn_indices: Optional[torch.Tensor] = None,
|
||||
sparse_attn_offsets: Optional[torch.Tensor] = None,
|
||||
sparse_attn_indices_block_size: int = 1,
|
||||
sparse_mla_topk: int = 0,
|
||||
**kwargs,
|
||||
):
|
||||
@ -241,6 +242,7 @@ class TrtllmAttentionWrapper:
|
||||
sparse_kv_offsets (torch.Tensor): The batch offsets for the sparse KV indices, with shape of (num_contexts + 1) on GPU.
|
||||
sparse_attn_indices (torch.Tensor): The sparse indices for the attention layer, with shape of (num_heads_kv, num_sparse_tokens) on GPU.
|
||||
sparse_attn_offsets (torch.Tensor): The batch offsets for the sparse attention indices, with shape of (num_generations + 1) on GPU.
|
||||
sparse_attn_indices_block_size (int): The granularity of the sparse attention indices, used by block sparse attention.
|
||||
sparse_mla_topk (int): The topk for the sparse MLA, used by DSA attention.
|
||||
"""
|
||||
self.layer_idx = layer_idx
|
||||
@ -283,6 +285,7 @@ class TrtllmAttentionWrapper:
|
||||
self.sparse_kv_offsets = sparse_kv_offsets
|
||||
self.sparse_attn_indices = sparse_attn_indices
|
||||
self.sparse_attn_offsets = sparse_attn_offsets
|
||||
self.sparse_attn_indices_block_size = sparse_attn_indices_block_size
|
||||
self.sparse_mla_topk = sparse_mla_topk
|
||||
if max_sequence_length > self.rope_params.max_positions:
|
||||
self.rope_params.max_positions = max_sequence_length
|
||||
@ -525,6 +528,7 @@ class TrtllmAttentionWrapper:
|
||||
self.sparse_kv_offsets,
|
||||
self.sparse_attn_indices,
|
||||
self.sparse_attn_offsets,
|
||||
self.sparse_attn_indices_block_size,
|
||||
self.sparse_mla_topk,
|
||||
cu_q_seqlens,
|
||||
cu_kv_seqlens,
|
||||
@ -1308,11 +1312,14 @@ class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]):
|
||||
)
|
||||
|
||||
sparse_kv_indices, sparse_kv_offsets, sparse_attn_indices, sparse_attn_offsets = None, None, None, None
|
||||
sparse_attn_indices_block_size = 1
|
||||
if self.sparse_attention_config is not None:
|
||||
sparse_kv_indices, sparse_kv_offsets = self.sparse_kv_predict(
|
||||
q, k, metadata, **kwargs)
|
||||
sparse_attn_indices, sparse_attn_offsets = self.sparse_attn_predict(
|
||||
q, k, metadata, **kwargs)
|
||||
sparse_attn_indices_block_size = self.sparse_attention_config.get_indices_block_size(
|
||||
)
|
||||
|
||||
self.wrapper.plan(
|
||||
layer_idx=self.get_local_layer_idx(metadata),
|
||||
@ -1366,8 +1373,10 @@ class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]):
|
||||
sparse_kv_offsets=sparse_kv_offsets,
|
||||
sparse_attn_indices=sparse_attn_indices,
|
||||
sparse_attn_offsets=sparse_attn_offsets,
|
||||
sparse_attn_indices_block_size=sparse_attn_indices_block_size,
|
||||
sparse_mla_topk=metadata.sparse_mla_topk if hasattr(
|
||||
metadata, 'sparse_mla_topk') else 0)
|
||||
metadata, 'sparse_mla_topk') else 0,
|
||||
)
|
||||
out_dtype = None
|
||||
if out_scale is not None:
|
||||
if use_nvfp4_output:
|
||||
@ -1589,18 +1598,6 @@ class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]):
|
||||
self.mla_params.v_head_dim,
|
||||
)
|
||||
|
||||
def sparse_attn_predict(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
k: Optional[torch.Tensor],
|
||||
metadata: TrtllmAttentionMetadata,
|
||||
**kwargs,
|
||||
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
"""
|
||||
Predict sparse attn indices. It's implemented in the derived class.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def sparse_kv_predict(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
@ -1613,6 +1610,18 @@ class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def sparse_attn_predict(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
k: Optional[torch.Tensor],
|
||||
metadata: TrtllmAttentionMetadata,
|
||||
**kwargs,
|
||||
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
"""
|
||||
Predict sparse attn indices. It's implemented in the derived class.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def mla_rope_generation(
|
||||
self,
|
||||
fused_q: torch.Tensor,
|
||||
|
||||
@ -215,6 +215,9 @@ class BaseSparseAttentionConfig(StrictBaseModel):
|
||||
"""
|
||||
return True
|
||||
|
||||
def get_indices_block_size(self) -> int:
|
||||
return 1
|
||||
|
||||
|
||||
class RocketSparseAttentionConfig(BaseSparseAttentionConfig):
|
||||
"""
|
||||
@ -238,6 +241,9 @@ class RocketSparseAttentionConfig(BaseSparseAttentionConfig):
|
||||
def supports_backend(self, backend: str) -> bool:
|
||||
return backend == "pytorch"
|
||||
|
||||
def get_indices_block_size(self) -> int:
|
||||
return self.page_size
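# Illustrative note (numbers assumed, not part of this diff): RocketKV emits
# page-granular indices, so with page_size = 3 an index value of 5 addresses
# tokens 15..17 of the sequence. get_indices_block_size() reports this granularity
# so the attention backend can expand indices back to token positions, e.g.:
#   cfg = RocketSparseAttentionConfig(page_size=3)
#   cfg.get_indices_block_size()  # -> 3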
|
||||
|
||||
|
||||
class DeepSeekSparseAttentionConfig(BaseSparseAttentionConfig):
|
||||
"""
|
||||
|
||||
@ -2,10 +2,18 @@ import json
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from utils.llm_data import llm_models_root
|
||||
|
||||
import tensorrt_llm
|
||||
from tensorrt_llm import LLM, SamplingParams
|
||||
from tensorrt_llm.llmapi import KvCacheConfig, RocketSparseAttentionConfig
|
||||
from tensorrt_llm._torch.attention_backend.sparse.rocket import (
|
||||
RocketKVCacheManager, RocketTrtllmAttention, RocketTrtllmAttentionMetadata,
|
||||
RocketVanillaAttention, RocketVanillaAttentionMetadata)
|
||||
from tensorrt_llm._torch.metadata import KVCacheParams
|
||||
from tensorrt_llm.llmapi import (CudaGraphConfig, KvCacheConfig,
|
||||
RocketSparseAttentionConfig)
|
||||
from tensorrt_llm.mapping import Mapping
|
||||
|
||||
|
||||
@pytest.mark.parametrize("backend", ["pytorch"])
|
||||
@ -25,6 +33,11 @@ def test_model(backend, model_name, attention_backend):
|
||||
prompt_budget=2048,
|
||||
)
|
||||
|
||||
cuda_graph_config = CudaGraphConfig(
|
||||
batch_sizes=[1, 2, 4, 8, 16],
|
||||
enable_padding=True,
|
||||
)
|
||||
|
||||
llm = LLM(
|
||||
model=model_dir,
|
||||
backend=backend,
|
||||
@ -32,10 +45,10 @@ def test_model(backend, model_name, attention_backend):
|
||||
attn_backend=attention_backend,
|
||||
sparse_attention_config=sparse_attention_config,
|
||||
max_batch_size=max_batch_size,
|
||||
max_seq_len=8192,
|
||||
max_num_tokens=8192,
|
||||
cuda_graph_config=
|
||||
None, # sparse attention does not support cuda graph now
|
||||
max_seq_len=20480,
|
||||
max_num_tokens=81920,
|
||||
cuda_graph_config=None
|
||||
if attention_backend == "VANILLA" else cuda_graph_config,
|
||||
)
|
||||
|
||||
inputs, references = [], []
|
||||
@ -75,6 +88,559 @@ def test_model(backend, model_name, attention_backend):
|
||||
assert acc >= 0.9, 'accuracy test of rocketkv sparse attention failed'
|
||||
|
||||
|
||||
def create_rocket_kv_cache_manager(num_layers, num_kv_heads, head_dim,
|
||||
tokens_per_block, max_seq_len,
|
||||
max_batch_size, dtype, sparse_attn_config):
|
||||
mapping = Mapping(world_size=1, tp_size=1, rank=0)
|
||||
num_blocks = 100
|
||||
kv_cache_config = KvCacheConfig(max_tokens=num_blocks * tokens_per_block,
|
||||
enable_block_reuse=False)
|
||||
|
||||
kv_cache_manager = RocketKVCacheManager(
|
||||
kv_cache_config,
|
||||
tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF,
|
||||
num_layers=num_layers,
|
||||
num_kv_heads=num_kv_heads,
|
||||
head_dim=head_dim,
|
||||
tokens_per_block=tokens_per_block,
|
||||
max_seq_len=max_seq_len,
|
||||
max_batch_size=max_batch_size,
|
||||
mapping=mapping,
|
||||
dtype=dtype,
|
||||
sparse_attn_config=sparse_attn_config,
|
||||
)
|
||||
return kv_cache_manager
|
||||
|
||||
|
||||
def create_test_metadata(seq_lens, num_contexts, past_seen_tokens, request_ids,
|
||||
kv_cache_manager, sparse_attn_config, metadata_cls):
|
||||
prompt_lens = []
|
||||
for i, (seq_len, past_token) in enumerate(zip(seq_lens, past_seen_tokens)):
|
||||
if i < num_contexts:
|
||||
prompt_lens.append(seq_len)
|
||||
else:
|
||||
prompt_lens.append(past_token + seq_len)
|
||||
|
||||
metadata = metadata_cls(
|
||||
seq_lens=torch.tensor(seq_lens, dtype=torch.int),
|
||||
num_contexts=num_contexts,
|
||||
kv_cache_params=KVCacheParams(
|
||||
use_cache=True, num_cached_tokens_per_seq=past_seen_tokens),
|
||||
max_num_requests=len(seq_lens),
|
||||
max_num_sequences=len(seq_lens),
|
||||
max_num_tokens=8192,
|
||||
kv_cache_manager=kv_cache_manager,
|
||||
request_ids=request_ids,
|
||||
prompt_lens=prompt_lens,
|
||||
sparse_attention_config=sparse_attn_config,
|
||||
)
|
||||
metadata.prepare()
|
||||
return metadata
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"batch_size,num_contexts",
|
||||
[
|
||||
(1, 1), # bs=1
|
||||
(4, 4), # bs=4, context only (4 contexts)
|
||||
(6, 3), # bs=6, mixed (3 contexts + 3 generations)
|
||||
])
|
||||
def test_sparse_kv_predict(batch_size, num_contexts):
|
||||
"""
|
||||
Test sparse_kv_predict against vanilla _get_snapkv_indices.
|
||||
|
||||
This test verifies that the batched implementation produces the same results
|
||||
as the single-request implementation for SnapKV sparse attention.
|
||||
"""
|
||||
|
||||
# Test configuration
|
||||
num_heads = 32
|
||||
num_kv_heads = 8
|
||||
head_dim = 128
|
||||
device = torch.device('cuda')
|
||||
dtype = torch.bfloat16
|
||||
|
||||
sparse_attn_config = RocketSparseAttentionConfig(
|
||||
window_size=32,
|
||||
kernel_size=3,
|
||||
prompt_budget=256,
|
||||
page_size=3,
|
||||
)
|
||||
|
||||
# Create sequence lengths - mix short and long sequences in context phase
|
||||
seq_lens = []
|
||||
past_seen_tokens = []
|
||||
for i in range(batch_size):
|
||||
if i < num_contexts:
|
||||
# Context phase: mix sequences shorter and longer than prompt_budget
|
||||
if i % 2 == 1 and batch_size > 1:
|
||||
# Short sequence: seq_len < prompt_budget
|
||||
seq_lens.append(
|
||||
torch.randint(sparse_attn_config.prompt_budget // 2,
|
||||
sparse_attn_config.prompt_budget - 10,
|
||||
(1, )).item())
|
||||
else:
|
||||
# Long sequence: seq_len > prompt_budget
|
||||
seq_lens.append(
|
||||
torch.randint(sparse_attn_config.prompt_budget,
|
||||
sparse_attn_config.prompt_budget + 200,
|
||||
(1, )).item())
|
||||
past_seen_tokens.append(0)
|
||||
else:
|
||||
# Generation phase: single token
|
||||
seq_lens.append(1)
|
||||
past_seen_tokens.append(torch.randint(100, 200, (1, )).item())
|
||||
|
||||
request_ids = list(range(batch_size))
|
||||
|
||||
num_layers = 1
|
||||
tokens_per_block = 64
|
||||
max_seq_len = 4096
|
||||
if dtype == torch.float16:
|
||||
kv_cache_dtype = tensorrt_llm.bindings.DataType.HALF
|
||||
elif dtype == torch.bfloat16:
|
||||
kv_cache_dtype = tensorrt_llm.bindings.DataType.BF16
|
||||
else:
|
||||
raise ValueError("Invalid dtype")
|
||||
|
||||
vanilla_tokens_per_block = max_seq_len # Each sequence in one block
|
||||
|
||||
trtllm_kv_cache_manager = create_rocket_kv_cache_manager(
|
||||
num_layers, num_kv_heads, head_dim, tokens_per_block, max_seq_len,
|
||||
batch_size, kv_cache_dtype, sparse_attn_config)
|
||||
vanilla_kv_cache_manager = create_rocket_kv_cache_manager(
|
||||
num_layers, num_kv_heads, head_dim, vanilla_tokens_per_block,
|
||||
max_seq_len, batch_size, kv_cache_dtype, sparse_attn_config)
|
||||
|
||||
# Add dummy requests to both cache managers
|
||||
token_nums = [
|
||||
seq_len + past_token
|
||||
for seq_len, past_token in zip(seq_lens, past_seen_tokens)
|
||||
]
|
||||
trtllm_kv_cache_manager.add_dummy_requests(request_ids, token_nums)
|
||||
vanilla_kv_cache_manager.add_dummy_requests(request_ids, token_nums)
|
||||
|
||||
trtllm_attn = RocketTrtllmAttention(
|
||||
layer_idx=0,
|
||||
num_heads=num_heads,
|
||||
head_dim=head_dim,
|
||||
num_kv_heads=num_kv_heads,
|
||||
sparse_attention_config=sparse_attn_config,
|
||||
)
|
||||
|
||||
vanilla_attn = RocketVanillaAttention(
|
||||
layer_idx=0,
|
||||
num_heads=num_heads,
|
||||
head_dim=head_dim,
|
||||
num_kv_heads=num_kv_heads,
|
||||
sparse_attention_config=sparse_attn_config,
|
||||
)
|
||||
|
||||
trtllm_metadata = create_test_metadata(seq_lens, num_contexts,
|
||||
past_seen_tokens, request_ids,
|
||||
trtllm_kv_cache_manager,
|
||||
sparse_attn_config,
|
||||
RocketTrtllmAttentionMetadata)
|
||||
vanilla_metadata = create_test_metadata(seq_lens, num_contexts,
|
||||
past_seen_tokens, request_ids,
|
||||
vanilla_kv_cache_manager,
|
||||
sparse_attn_config,
|
||||
RocketVanillaAttentionMetadata)
|
||||
|
||||
total_tokens = sum(seq_lens)
|
||||
qkv = torch.randn(total_tokens, (num_heads + 2 * num_kv_heads) * head_dim,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
|
||||
trtllm_sparse_kv_indices, trtllm_sparse_kv_offsets = trtllm_attn.sparse_kv_predict(
|
||||
qkv, None, trtllm_metadata)
|
||||
|
||||
vanilla_sparse_kv_indices_list = []
|
||||
offset = 0
|
||||
for i in range(num_contexts):
|
||||
seq_len = seq_lens[i]
|
||||
single_qkv = qkv[offset:offset + seq_len]
|
||||
q, k, _ = single_qkv.split([
|
||||
num_heads * head_dim, num_kv_heads * head_dim,
|
||||
num_kv_heads * head_dim
|
||||
],
|
||||
dim=-1)
|
||||
q = q.view(1, seq_len, num_heads, head_dim).transpose(1, 2)
|
||||
k = k.view(1, seq_len, num_kv_heads, head_dim)
|
||||
|
||||
if seq_len <= sparse_attn_config.prompt_budget:
|
||||
# Short sequences: vanilla returns None, but trtllm returns [0, 1, ..., seq_len-1]
|
||||
# Generate expected indices for comparison
|
||||
short_indices = torch.arange(seq_len,
|
||||
device=device,
|
||||
dtype=torch.int32).unsqueeze(0).expand(
|
||||
num_kv_heads, -1)
|
||||
vanilla_sparse_kv_indices_list.append(short_indices)
|
||||
else:
|
||||
vanilla_indices = vanilla_attn._get_snapkv_indices(q, k, i)
|
||||
if vanilla_indices is not None:
|
||||
vanilla_indices = vanilla_indices.squeeze(0).transpose(
|
||||
0, 1).contiguous()
|
||||
vanilla_sparse_kv_indices_list.append(vanilla_indices)
|
||||
|
||||
offset += seq_len
|
||||
|
||||
if len(vanilla_sparse_kv_indices_list) > 0:
|
||||
vanilla_sparse_kv_indices = torch.cat(vanilla_sparse_kv_indices_list,
|
||||
dim=-1).contiguous()
|
||||
else:
|
||||
vanilla_sparse_kv_indices = None
|
||||
|
||||
# Compare results
|
||||
if trtllm_sparse_kv_indices is not None:
|
||||
assert vanilla_sparse_kv_indices is not None, "Vanilla should also produce indices"
|
||||
|
||||
assert trtllm_sparse_kv_indices.shape == vanilla_sparse_kv_indices.shape, \
|
||||
f"Shape mismatch: {trtllm_sparse_kv_indices.shape} vs {vanilla_sparse_kv_indices.shape}"
|
||||
|
||||
# Check indices overlap per batch and per head
|
||||
num_kv_heads = trtllm_sparse_kv_indices.shape[0]
|
||||
|
||||
# trtllm_sparse_kv_offsets tells where each batch's indices start/end
|
||||
trtllm_offsets = trtllm_sparse_kv_offsets.cpu().tolist()
|
||||
|
||||
overlap_ratios = []
|
||||
batch_overlap_details = []
|
||||
|
||||
for batch_idx in range(num_contexts):
|
||||
start_idx = trtllm_offsets[batch_idx]
|
||||
end_idx = trtllm_offsets[batch_idx + 1]
|
||||
|
||||
batch_overlaps = []
|
||||
for head_idx in range(num_kv_heads):
|
||||
trtllm_batch = trtllm_sparse_kv_indices[
|
||||
head_idx, start_idx:end_idx].cpu().tolist()
|
||||
vanilla_batch = vanilla_sparse_kv_indices[
|
||||
head_idx, start_idx:end_idx].cpu().tolist()
|
||||
|
||||
trtllm_set = set(trtllm_batch)
|
||||
vanilla_set = set(vanilla_batch)
|
||||
|
||||
# Calculate overlap
|
||||
overlap = len(vanilla_set & trtllm_set)
|
||||
overlap_ratio = overlap / len(vanilla_set) if len(
|
||||
vanilla_set) > 0 else 1.0
|
||||
batch_overlaps.append(overlap_ratio)
|
||||
overlap_ratios.append(overlap_ratio)
|
||||
|
||||
avg_batch_overlap = sum(batch_overlaps) / len(batch_overlaps)
|
||||
batch_overlap_details.append(
|
||||
f"Batch {batch_idx}: {avg_batch_overlap:.4f}")
|
||||
|
||||
avg_overlap_ratio = sum(overlap_ratios) / len(overlap_ratios)
|
||||
print(f"Average overlap ratio: {avg_overlap_ratio:.4f}")
|
||||
print(f"Per-batch average: {batch_overlap_details}")
|
||||
|
||||
assert avg_overlap_ratio >= 0.98, \
|
||||
f"Indices overlap ratio {avg_overlap_ratio:.4f} is too low (< 0.98)"
|
||||
else:
|
||||
assert vanilla_sparse_kv_indices is None, "Both should return None when no sparse attention is needed"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"batch_size,num_contexts",
|
||||
[
|
||||
(1, 0), # bs=1, generation only (1 generation)
|
||||
(2, 0), # bs=2, generation only (2 generations)
|
||||
(3, 0), # bs=3, generation only (3 generations)
|
||||
(5, 3), # bs=5, mixed (3 contexts + 2 generations)
|
||||
(6, 2), # bs=6, mixed (2 ctx + 4 gen)
|
||||
])
|
||||
def test_sparse_attn_predict(batch_size, num_contexts):
|
||||
"""Test sparse_attn_predict against vanilla _rocketkv_selection."""
|
||||
num_generations = batch_size - num_contexts
|
||||
|
||||
# Test configuration
|
||||
num_heads = 32
|
||||
num_kv_heads = 8
|
||||
head_dim = 128
|
||||
device = torch.device('cuda')
|
||||
dtype = torch.bfloat16
|
||||
|
||||
sparse_attn_config_vanilla = RocketSparseAttentionConfig(
|
||||
window_size=32,
|
||||
kernel_size=3,
|
||||
prompt_budget=256,
|
||||
page_size=3,
|
||||
topk=128,
|
||||
topr=96,
|
||||
)
|
||||
sparse_attn_config_trtllm = RocketSparseAttentionConfig(
|
||||
window_size=32,
|
||||
kernel_size=3,
|
||||
prompt_budget=256,
|
||||
page_size=3,
|
||||
topk=43,
|
||||
topr=96,
|
||||
)
|
||||
|
||||
# Create sequence lengths
|
||||
seq_lens = []
|
||||
past_seen_tokens = []
|
||||
for i in range(batch_size):
|
||||
if i < num_contexts:
|
||||
# Context phase: longer sequences
|
||||
seq_lens.append(torch.randint(300, 400, (1, )).item())
|
||||
past_seen_tokens.append(0)
|
||||
else:
|
||||
# Generation phase: single token
|
||||
seq_lens.append(1)
|
||||
# 128 is the minimum number of tokens for shape alignment
|
||||
past_seen_tokens.append(torch.randint(128, 300, (1, )).item())
|
||||
|
||||
request_ids = list(range(batch_size))
|
||||
|
||||
num_layers = 1
|
||||
tokens_per_block = 64
|
||||
max_seq_len = 4096
|
||||
if dtype == torch.float16:
|
||||
kv_cache_dtype = tensorrt_llm.bindings.DataType.HALF
|
||||
elif dtype == torch.bfloat16:
|
||||
kv_cache_dtype = tensorrt_llm.bindings.DataType.BF16
|
||||
else:
|
||||
raise ValueError("Invalid dtype")
|
||||
|
||||
vanilla_tokens_per_block = max_seq_len # Each sequence in one block
|
||||
|
||||
trtllm_kv_cache_manager = create_rocket_kv_cache_manager(
|
||||
num_layers, num_kv_heads, head_dim, tokens_per_block, max_seq_len,
|
||||
batch_size, kv_cache_dtype, sparse_attn_config_trtllm)
|
||||
vanilla_kv_cache_manager = create_rocket_kv_cache_manager(
|
||||
num_layers, num_kv_heads, head_dim, vanilla_tokens_per_block,
|
||||
max_seq_len, batch_size, kv_cache_dtype, sparse_attn_config_vanilla)
|
||||
|
||||
# Add dummy requests to both cache managers
|
||||
token_nums = [
|
||||
seq_len + past_token
|
||||
for seq_len, past_token in zip(seq_lens, past_seen_tokens)
|
||||
]
|
||||
trtllm_kv_cache_manager.add_dummy_requests(request_ids, token_nums)
|
||||
vanilla_kv_cache_manager.add_dummy_requests(request_ids, token_nums)
|
||||
|
||||
trtllm_attn = RocketTrtllmAttention(
|
||||
layer_idx=0,
|
||||
num_heads=num_heads,
|
||||
head_dim=head_dim,
|
||||
num_kv_heads=num_kv_heads,
|
||||
sparse_attention_config=sparse_attn_config_trtllm,
|
||||
)
|
||||
|
||||
vanilla_attn = RocketVanillaAttention(
|
||||
layer_idx=0,
|
||||
num_heads=num_heads,
|
||||
head_dim=head_dim,
|
||||
num_kv_heads=num_kv_heads,
|
||||
sparse_attention_config=sparse_attn_config_vanilla,
|
||||
)
|
||||
|
||||
trtllm_metadata = create_test_metadata(seq_lens, num_contexts,
|
||||
past_seen_tokens, request_ids,
|
||||
trtllm_kv_cache_manager,
|
||||
sparse_attn_config_trtllm,
|
||||
RocketTrtllmAttentionMetadata)
|
||||
vanilla_metadata = create_test_metadata(seq_lens, num_contexts,
|
||||
past_seen_tokens, request_ids,
|
||||
vanilla_kv_cache_manager,
|
||||
sparse_attn_config_vanilla,
|
||||
RocketVanillaAttentionMetadata)
|
||||
|
||||
total_tokens = sum(seq_lens)
|
||||
qkv = torch.randn(total_tokens, (num_heads + 2 * num_kv_heads) * head_dim,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
|
||||
for layer_idx in range(num_layers):
|
||||
trtllm_kt_buf = trtllm_kv_cache_manager.get_kt_buffers(layer_idx)
|
||||
vanilla_kt_buf = vanilla_kv_cache_manager.get_kt_buffers(layer_idx)
|
||||
|
||||
torch.nn.init.normal_(trtllm_kt_buf)
|
||||
|
||||
# Map trtllm data to vanilla based on block offsets
|
||||
# TRTLLM: (num_blocks, kt_tokens_per_block, num_kv_heads, 2*head_dim)
|
||||
# VANILLA: (num_blocks, kt_tokens_per_block_vanilla, num_kv_heads, 2*head_dim)
|
||||
trtllm_kt_tokens_per_block = trtllm_kv_cache_manager.kt_tokens_per_block
|
||||
|
||||
trtllm_block_offsets = trtllm_metadata.kt_cache_block_offsets
|
||||
vanilla_block_offsets = vanilla_metadata.kt_cache_block_offsets
|
||||
|
||||
for req_idx in range(num_contexts, batch_size):
|
||||
# Get the number of KT tokens for this request
|
||||
past_token = past_seen_tokens[req_idx]
|
||||
num_kt_tokens = (past_token + 1 +
|
||||
sparse_attn_config_trtllm.page_size -
|
||||
1) // sparse_attn_config_trtllm.page_size
|
||||
|
||||
# Get block offsets for this request
|
||||
trtllm_blocks = trtllm_block_offsets[req_idx]
|
||||
vanilla_blocks = vanilla_block_offsets[req_idx]
|
||||
|
||||
# Copy data from trtllm blocks to vanilla blocks
|
||||
kt_token_idx = 0
|
||||
vanilla_block_idx = 0
|
||||
|
||||
# For trtllm: iterate through blocks and copy KT tokens
|
||||
for trtllm_block_local_idx in range(len(trtllm_blocks)):
|
||||
if kt_token_idx >= num_kt_tokens:
|
||||
break
|
||||
|
||||
trtllm_block = trtllm_blocks[trtllm_block_local_idx]
|
||||
if trtllm_block < 0:
|
||||
break
|
||||
|
||||
# How many KT tokens in this trtllm block
|
||||
kt_tokens_in_this_block = min(trtllm_kt_tokens_per_block,
|
||||
num_kt_tokens - kt_token_idx)
|
||||
|
||||
# Copy to vanilla buffer
|
||||
vanilla_block = vanilla_blocks[vanilla_block_idx]
|
||||
if vanilla_block >= 0:
|
||||
vanilla_kt_buf[vanilla_block, kt_token_idx:kt_token_idx +
|
||||
kt_tokens_in_this_block].copy_(trtllm_kt_buf[
|
||||
trtllm_block, :kt_tokens_in_this_block])
|
||||
|
||||
kt_token_idx += kt_tokens_in_this_block
|
||||
|
||||
trtllm_sparse_attn_indices, trtllm_sparse_attn_offsets = trtllm_attn.sparse_attn_predict(
|
||||
qkv, None, trtllm_metadata)
|
||||
|
||||
vanilla_sparse_attn_indices_list = []
|
||||
offset = sum(seq_lens[:num_contexts]) # Skip context tokens
|
||||
|
||||
for i in range(num_contexts, batch_size):
|
||||
seq_len = seq_lens[i]
|
||||
single_qkv = qkv[offset:offset + seq_len]
|
||||
q, k, _ = single_qkv.split([
|
||||
num_heads * head_dim, num_kv_heads * head_dim,
|
||||
num_kv_heads * head_dim
|
||||
],
|
||||
dim=-1)
|
||||
q = q.view(1, seq_len, num_heads, head_dim).transpose(1, 2)
|
||||
k = k.view(1, seq_len, num_kv_heads, head_dim)
|
||||
|
||||
past_seen_token = past_seen_tokens[i]
|
||||
vanilla_indices = vanilla_attn._rocketkv_selection(
|
||||
q, k, vanilla_metadata, past_seen_token, i)
|
||||
vanilla_sparse_attn_indices_list.append(vanilla_indices.squeeze(0))
|
||||
offset += seq_len
|
||||
|
||||
if trtllm_sparse_attn_indices is not None:
|
||||
assert len(vanilla_sparse_attn_indices_list
|
||||
) > 0, "Vanilla should also produce indices"
|
||||
|
||||
vanilla_sparse_attn_indices = torch.cat(
|
||||
vanilla_sparse_attn_indices_list, dim=0).transpose(0,
|
||||
1).contiguous()
|
||||
|
||||
# Apply interleave operation to trtllm indices
|
||||
# For each head, multiply indices by page_size and expand to include all tokens in each page
|
||||
page_size = sparse_attn_config_trtllm.page_size
|
||||
num_kv_heads, total_indices = trtllm_sparse_attn_indices.shape
|
||||
|
||||
interleaved_indices_list = []
|
||||
for head_idx in range(num_kv_heads):
|
||||
head_indices = trtllm_sparse_attn_indices[
|
||||
head_idx] # Shape: [total_indices]
|
||||
|
||||
page_starts = head_indices * page_size # Shape: [total_indices]
|
||||
|
||||
expanded_indices = []
|
||||
for page_start in page_starts:
|
||||
page_indices = torch.arange(page_start,
|
||||
page_start + page_size,
|
||||
device=page_starts.device)
|
||||
expanded_indices.append(page_indices)
|
||||
|
||||
head_interleaved = torch.cat(expanded_indices, dim=0)
|
||||
|
||||
# Slice to match vanilla shape
|
||||
target_length = vanilla_sparse_attn_indices.shape[1]
|
||||
head_interleaved = head_interleaved[:target_length]
|
||||
|
||||
interleaved_indices_list.append(head_interleaved)
|
||||
|
||||
# Stack all heads
|
||||
trtllm_sparse_attn_indices = torch.stack(interleaved_indices_list,
|
||||
dim=0)
|
||||
|
||||
assert trtllm_sparse_attn_indices.shape == vanilla_sparse_attn_indices.shape, \
|
||||
f"Shape mismatch: {trtllm_sparse_attn_indices.shape} vs {vanilla_sparse_attn_indices.shape}"
|
||||
|
||||
trtllm_sparse_attn_indices = trtllm_sparse_attn_indices.sort(
|
||||
dim=-1).values
|
||||
vanilla_sparse_attn_indices = vanilla_sparse_attn_indices.sort(
|
||||
dim=-1).values
|
||||
|
||||
# Check indices overlap per batch and per head
|
||||
num_kv_heads = trtllm_sparse_attn_indices.shape[0]
|
||||
|
||||
trtllm_offsets = trtllm_sparse_attn_offsets.cpu().tolist()
|
||||
|
||||
overlap_ratios = []
|
||||
batch_overlap_details = []
|
||||
|
||||
num_generations = batch_size - num_contexts
|
||||
for batch_idx in range(num_generations):
|
||||
start_idx = trtllm_offsets[batch_idx]
|
||||
end_idx = trtllm_offsets[batch_idx + 1]
|
||||
|
||||
batch_overlaps = []
|
||||
for head_idx in range(num_kv_heads):
|
||||
trtllm_batch = trtllm_sparse_attn_indices[
|
||||
head_idx, start_idx:end_idx].cpu().tolist()
|
||||
vanilla_batch = vanilla_sparse_attn_indices[
|
||||
head_idx, start_idx:end_idx].cpu().tolist()
|
||||
|
||||
trtllm_set = set(trtllm_batch)
|
||||
vanilla_set = set(vanilla_batch)
|
||||
|
||||
# Calculate overlap
|
||||
overlap = len(vanilla_set & trtllm_set)
|
||||
overlap_ratio = overlap / len(vanilla_set) if len(
|
||||
vanilla_set) > 0 else 1.0
|
||||
batch_overlaps.append(overlap_ratio)
|
||||
overlap_ratios.append(overlap_ratio)
|
||||
|
||||
avg_batch_overlap = sum(batch_overlaps) / len(batch_overlaps)
|
||||
batch_overlap_details.append(
|
||||
f"Batch {batch_idx}: {avg_batch_overlap:.4f}")
|
||||
|
||||
avg_overlap_ratio = sum(overlap_ratios) / len(overlap_ratios)
|
||||
print(f"Average overlap ratio: {avg_overlap_ratio:.4f}")
|
||||
print(f"Per-batch average: {batch_overlap_details}")
|
||||
|
||||
threshold = 0.94
|
||||
assert avg_overlap_ratio >= threshold, \
|
||||
f"Indices overlap ratio {avg_overlap_ratio:.4f} is too low (< {threshold})"
|
||||
else:
|
||||
assert len(
|
||||
vanilla_sparse_attn_indices_list
|
||||
) == 0, "Both should return None when no sparse attention is needed"
|
||||
|
||||
|
||||
if __name__ == '__main__':
# RocketKV e2e tests
print("=== Testing RocketKV E2E tests ===")
test_model("pytorch", "llama-3.1-model/Llama-3.1-8B-Instruct", "VANILLA")
test_model("pytorch", "llama-3.1-model/Llama-3.1-8B-Instruct", "TRTLLM")

# Unit tests for sparse_kv_predict
print("\n=== Testing sparse_kv_predict ===")
test_sparse_kv_predict(1, 1)  # bs=1, context only
test_sparse_kv_predict(2, 2)  # bs=2, context only
test_sparse_kv_predict(6, 3)  # bs=6, mixed (3 ctx + 3 gen)

# Unit tests for sparse_attn_predict
print("\n=== Testing sparse_attn_predict ===")
test_sparse_attn_predict(1, 0)  # bs=1, generation only
test_sparse_attn_predict(2, 0)  # bs=2, generation only
test_sparse_attn_predict(3, 0)  # bs=3, generation only
test_sparse_attn_predict(5, 3)  # bs=5, mixed (3 ctx + 2 gen)
test_sparse_attn_predict(6, 2)  # bs=6, mixed (2 ctx + 4 gen)
413
tests/unittest/_torch/attention/sparse/test_triton_bmm.py
Normal file
@ -0,0 +1,413 @@
|
||||
import math
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tensorrt_llm._torch.attention_backend.sparse.kernel import (
|
||||
triton_bmm,
|
||||
triton_rocket_paged_kt_cache_bmm,
|
||||
)
|
||||
|
||||
|
||||
def pytorch_reference_bmm(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
q_cu_seqlens: torch.Tensor,
|
||||
k_cu_seqlens: torch.Tensor,
|
||||
batch_size: int,
|
||||
sm_scale: float = None,
|
||||
causal: bool = False,
|
||||
) -> torch.Tensor:
|
||||
num_q_heads, total_q_tokens, head_dim = q.shape
|
||||
num_k_heads, total_k_tokens, _ = k.shape
|
||||
|
||||
if sm_scale is None:
|
||||
sm_scale = 1.0 / math.sqrt(head_dim)
|
||||
|
||||
# Compute q_len_per_seq
|
||||
q_len_per_seq = total_q_tokens // batch_size
|
||||
|
||||
scores = torch.full(
|
||||
(num_q_heads, q_len_per_seq, total_k_tokens),
|
||||
float("-inf"),
|
||||
dtype=torch.float32,
|
||||
device=q.device,
|
||||
)
|
||||
|
||||
# Process each batch
|
||||
for batch_idx in range(batch_size):
|
||||
q_start = q_cu_seqlens[batch_idx].item()
|
||||
q_end = q_cu_seqlens[batch_idx + 1].item()
|
||||
k_start = k_cu_seqlens[batch_idx].item()
|
||||
k_end = k_cu_seqlens[batch_idx + 1].item()
|
||||
|
||||
q_seqlen = q_end - q_start
|
||||
k_seqlen = k_end - k_start
|
||||
|
||||
if q_seqlen <= 0 or k_seqlen <= 0:
|
||||
continue
|
||||
|
||||
q_batch = q[:, q_start:q_end, :] # [num_q_heads, q_seqlen, head_dim]
|
||||
|
||||
num_heads_per_kv = num_q_heads // num_k_heads
|
||||
|
||||
for head_idx in range(num_q_heads):
|
||||
k_head_idx = head_idx // num_heads_per_kv
|
||||
k_batch = k[k_head_idx, k_start:k_end, :] # [k_seqlen, head_dim]
|
||||
|
||||
qk = torch.matmul(q_batch[head_idx], k_batch.T) * sm_scale
|
||||
|
||||
if causal:
|
||||
causal_mask = torch.triu(
|
||||
torch.ones(q_seqlen, k_seqlen, device=q.device, dtype=torch.bool), diagonal=1
|
||||
)
|
||||
qk = qk.masked_fill(causal_mask, float("-inf"))
|
||||
|
||||
scores[head_idx, :q_seqlen, k_start:k_end] = qk
|
||||
|
||||
return scores
|
||||
|
||||
|
||||
def create_kt_cache_from_k(
|
||||
k: torch.Tensor,
|
||||
kv_lens: torch.Tensor,
|
||||
kt_page_size: int,
|
||||
tokens_per_block: int,
|
||||
num_kv_heads: int,
|
||||
head_dim: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, int]:
|
||||
"""
|
||||
Create paged KT cache tensor from continuous K tensor.
|
||||
|
||||
Args:
|
||||
k: Key tensor [num_gen_tokens, num_kv_heads * head_dim]
|
||||
kv_lens: Sequence lengths [num_gen_tokens]
|
||||
kt_page_size: Page size for KT tokens
|
||||
tokens_per_block: Tokens per cache block
|
||||
num_kv_heads: Number of KV heads
|
||||
head_dim: Head dimension
|
||||
|
||||
Returns:
|
||||
kt_cache_tensor: KT cache [num_blocks, tokens_per_block, num_kv_heads, 2*head_dim]
|
||||
kt_cache_block_offsets: Block offsets [num_gen_tokens, max_kt_blocks_per_seq]
|
||||
max_kt_blocks_per_seq: Maximum KT blocks per sequence
|
||||
"""
|
||||
num_gen_tokens = k.shape[0]
|
||||
device = k.device
|
||||
dtype = k.dtype
|
||||
|
||||
# Calculate number of kt tokens per sequence
|
||||
num_kt_tokens_per_seq = [
|
||||
(kv_len.item() + kt_page_size - 1) // kt_page_size for kv_len in kv_lens
|
||||
]
|
||||
max_kt_tokens = max(num_kt_tokens_per_seq)
|
||||
max_kt_blocks_per_seq = (max_kt_tokens + tokens_per_block - 1) // tokens_per_block
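# Worked example of the paging arithmetic above (illustrative numbers): kv_len = 200
# with kt_page_size = 3 gives ceil(200 / 3) = 67 KT tokens; with tokens_per_block = 64
# those occupy ceil(67 / 64) = 2 cache blocks for that sequence.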
|
||||
|
||||
# Calculate total number of blocks needed
|
||||
total_blocks_needed = sum(
|
||||
(kt_tokens + tokens_per_block - 1) // tokens_per_block
|
||||
for kt_tokens in num_kt_tokens_per_seq
|
||||
)
|
||||
|
||||
# Create KT cache tensor
|
||||
kt_cache_tensor = torch.zeros(
|
||||
(total_blocks_needed, tokens_per_block, num_kv_heads, 2 * head_dim),
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
# Create block offsets tensor
|
||||
kt_cache_block_offsets = torch.full(
|
||||
(num_gen_tokens, max_kt_blocks_per_seq), -1, dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
# Fill KT cache and block offsets
|
||||
current_block_idx = 0
|
||||
|
||||
for seq_idx in range(num_gen_tokens):
|
||||
kv_len = kv_lens[seq_idx].item()
|
||||
num_kt_tokens = num_kt_tokens_per_seq[seq_idx]
|
||||
|
||||
# Reshape k for this sequence: [num_kv_heads, head_dim]
|
||||
k_seq = k[seq_idx].view(num_kv_heads, head_dim)
|
||||
|
||||
# Process each kt token (page)
|
||||
for kt_idx in range(num_kt_tokens):
|
||||
page_start = kt_idx * kt_page_size
|
||||
|
||||
# For simplicity, we use the first token in the page as representative
|
||||
# In real usage, this would be min/max over the page
|
||||
# Here we just replicate the first token's value for testing
|
||||
token_idx = page_start
|
||||
if token_idx < kv_len:
|
||||
k_val = k_seq # [num_kv_heads, head_dim]
|
||||
|
||||
# Store k_min and k_max (for testing, we use same value)
|
||||
kt_min = k_val
|
||||
kt_max = k_val
|
||||
|
||||
# Determine which block this kt token belongs to
|
||||
block_offset = kt_idx // tokens_per_block
|
||||
token_offset_in_block = kt_idx % tokens_per_block
|
||||
|
||||
# Assign block index if not already assigned
|
||||
if kt_cache_block_offsets[seq_idx, block_offset] < 0:
|
||||
kt_cache_block_offsets[seq_idx, block_offset] = current_block_idx
|
||||
current_block_idx += 1
|
||||
|
||||
block_idx = kt_cache_block_offsets[seq_idx, block_offset].item()
|
||||
|
||||
# Store in cache: [block, token_in_block, head, 2*head_dim]
|
||||
kt_cache_tensor[block_idx, token_offset_in_block, :, :head_dim] = kt_min
|
||||
kt_cache_tensor[block_idx, token_offset_in_block, :, head_dim:] = kt_max
|
||||
|
||||
return kt_cache_tensor, kt_cache_block_offsets, max_kt_blocks_per_seq
|
||||
|
||||
|
||||
def pytorch_reference_paged_kt_cache_bmm(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
dim_pos: torch.Tensor,
|
||||
kv_lens: torch.Tensor,
|
||||
kt_page_size: int,
|
||||
sm_scale: float = None,
|
||||
) -> torch.Tensor:
|
||||
num_gen_tokens, num_kv_heads, num_heads_per_kv, head_dim = q.shape
|
||||
total_num_heads = num_kv_heads * num_heads_per_kv
|
||||
device = q.device
|
||||
|
||||
if sm_scale is None:
|
||||
sm_scale = 1.0 / math.sqrt(head_dim)
|
||||
|
||||
max_kt_tokens = max((kv_len.item() + kt_page_size - 1) // kt_page_size for kv_len in kv_lens)
|
||||
total_kt_tokens = num_gen_tokens * max_kt_tokens
|
||||
|
||||
scores = torch.zeros((total_num_heads, 1, total_kt_tokens), dtype=torch.float32, device=device)
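# Layout note (restating the shapes above): the score row for generation request b
# occupies columns [b * max_kt_tokens, b * max_kt_tokens + num_kt_tokens_b); the
# remaining columns of that slot stay zero and are masked out when comparing results.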
|
||||
|
||||
# Process each generation token
|
||||
for batch_idx in range(num_gen_tokens):
|
||||
kv_len = kv_lens[batch_idx].item()
|
||||
num_kt_tokens = (kv_len + kt_page_size - 1) // kt_page_size
|
||||
|
||||
q_batch = q[batch_idx] # [num_kv_heads, num_heads_per_kv, head_dim]
|
||||
k_batch = k[batch_idx].view(num_kv_heads, head_dim) # [num_kv_heads, head_dim]
|
||||
dim_pos_batch = dim_pos[batch_idx] # [num_kv_heads, head_dim]
|
||||
|
||||
output_offset = batch_idx * max_kt_tokens
|
||||
|
||||
for kv_head_idx in range(num_kv_heads):
|
||||
for q_head_idx in range(num_heads_per_kv):
|
||||
global_head_idx = kv_head_idx * num_heads_per_kv + q_head_idx
|
||||
|
||||
q_vec = q_batch[kv_head_idx, q_head_idx] # [head_dim]
|
||||
k_vec = k_batch[kv_head_idx] # [head_dim]
|
||||
dim_pos_vec = dim_pos_batch[kv_head_idx] # [head_dim]
|
||||
|
||||
# Simulate KT selection based on dim_pos
|
||||
k_selected = torch.where(dim_pos_vec > 0, k_vec, k_vec)
|
||||
|
||||
# Compute score for each kt token (simplified)
|
||||
for kt_idx in range(num_kt_tokens):
|
||||
score = torch.dot(q_vec, k_selected) * sm_scale
|
||||
scores[global_head_idx, 0, output_offset + kt_idx] = score
|
||||
|
||||
return scores
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"batch_size,q_len_per_seq,k_lens,num_q_heads,num_kv_heads,head_dim,causal",
|
||||
[
|
||||
# Single batch
|
||||
(1, 32, [128], 8, 8, 128, False),
|
||||
(1, 32, [128], 8, 8, 128, True),
|
||||
# Multiple batches with different k_len
|
||||
(3, 32, [64, 128, 256], 8, 8, 128, False),
|
||||
(4, 16, [100, 200, 150, 300], 32, 8, 128, False),
|
||||
# Edge cases
|
||||
(2, 1, [10, 20], 8, 8, 128, False), # q_len=1
|
||||
(2, 64, [64, 128], 16, 4, 64, True), # Different head_dim with causal
|
||||
],
|
||||
)
|
||||
def test_triton_bmm(batch_size, q_len_per_seq, k_lens, num_q_heads, num_kv_heads, head_dim, causal):
|
||||
device = torch.device("cuda")
|
||||
dtype = torch.float32
|
||||
|
||||
total_q_tokens = batch_size * q_len_per_seq
|
||||
q_cu_seqlens = torch.arange(0, batch_size + 1, device=device, dtype=torch.int32) * q_len_per_seq
|
||||
|
||||
k_lens_tensor = torch.tensor(k_lens, dtype=torch.int32, device=device)
|
||||
k_cu_seqlens = torch.cat(
|
||||
[torch.zeros(1, dtype=torch.int32, device=device), torch.cumsum(k_lens_tensor, dim=0)]
|
||||
)
|
||||
total_k_tokens = k_cu_seqlens[-1].item()
|
||||
|
||||
q = torch.randn((num_q_heads, total_q_tokens, head_dim), dtype=dtype, device=device)
|
||||
k = torch.randn((num_kv_heads, total_k_tokens, head_dim), dtype=dtype, device=device)
|
||||
|
||||
triton_scores = triton_bmm(
|
||||
q=q,
|
||||
k=k,
|
||||
q_cu_seqlens=q_cu_seqlens,
|
||||
k_cu_seqlens=k_cu_seqlens,
|
||||
batch_size=batch_size,
|
||||
sm_scale=None,
|
||||
causal=causal,
|
||||
)
|
||||
|
||||
reference_scores = pytorch_reference_bmm(
|
||||
q=q,
|
||||
k=k,
|
||||
q_cu_seqlens=q_cu_seqlens,
|
||||
k_cu_seqlens=k_cu_seqlens,
|
||||
batch_size=batch_size,
|
||||
sm_scale=None,
|
||||
causal=causal,
|
||||
)
|
||||
|
||||
# Compare results
|
||||
# Handle -inf values separately
|
||||
triton_finite = torch.isfinite(triton_scores)
|
||||
reference_finite = torch.isfinite(reference_scores)
|
||||
|
||||
# Check that inf/finite masks match
|
||||
assert torch.all(triton_finite == reference_finite), (
|
||||
"Finite/infinite mask mismatch between Triton and reference"
|
||||
)
|
||||
|
||||
# Compare finite values
|
||||
if triton_finite.any():
|
||||
max_diff = torch.max(
|
||||
torch.abs(triton_scores[triton_finite] - reference_scores[reference_finite])
|
||||
).item()
|
||||
|
||||
print(f"Max absolute difference: {max_diff:.6f}")
|
||||
|
||||
assert max_diff < 0.01, f"Max difference {max_diff} exceeds threshold"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"batch_size,kv_lens,num_kv_heads,num_heads_per_kv,head_dim,kt_page_size,tokens_per_block",
|
||||
[
|
||||
# Single batch
|
||||
(1, [128], 8, 4, 128, 3, 64),
|
||||
# Multiple batches with different kv_len
|
||||
(3, [100, 200, 150], 8, 4, 128, 3, 64),
|
||||
],
|
||||
)
|
||||
def test_triton_rocket_paged_kt_cache_bmm(
|
||||
batch_size, kv_lens, num_kv_heads, num_heads_per_kv, head_dim, kt_page_size, tokens_per_block
|
||||
):
|
||||
device = torch.device("cuda")
|
||||
dtype = torch.bfloat16
|
||||
|
||||
num_gen_tokens = batch_size
|
||||
|
||||
kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32, device=device)
|
||||
|
||||
# Create Q tensor: [num_gen_tokens, num_kv_heads, num_heads_per_kv, head_dim]
|
||||
q = torch.randn(
|
||||
(num_gen_tokens, num_kv_heads, num_heads_per_kv, head_dim), dtype=dtype, device=device
|
||||
)
|
||||
|
||||
# Create K tensor for reference: [num_gen_tokens, num_kv_heads * head_dim]
|
||||
k = torch.randn((num_gen_tokens, num_kv_heads * head_dim), dtype=dtype, device=device)
|
||||
|
||||
# Create dim_pos: [num_gen_tokens, num_kv_heads, head_dim]
|
||||
# Randomly set some dimensions to head_dim (positive) and others to 0
|
||||
dim_pos = torch.zeros(
|
||||
(num_gen_tokens, num_kv_heads, head_dim), dtype=torch.int32, device=device
|
||||
)
|
||||
for i in range(num_gen_tokens):
|
||||
for j in range(num_kv_heads):
|
||||
num_positive = torch.randint(0, head_dim, (1,)).item()
|
||||
positive_indices = torch.randperm(head_dim)[:num_positive]
|
||||
dim_pos[i, j, positive_indices] = head_dim
|
||||
|
||||
# Create paged KT cache
|
||||
kt_cache_tensor, kt_cache_block_offsets, max_kt_blocks_per_seq = create_kt_cache_from_k(
|
||||
k=k,
|
||||
kv_lens=kv_lens_tensor,
|
||||
kt_page_size=kt_page_size,
|
||||
tokens_per_block=tokens_per_block,
|
||||
num_kv_heads=num_kv_heads,
|
||||
head_dim=head_dim,
|
||||
)
|
||||
|
||||
# Calculate output offsets
|
||||
max_kt_tokens = max((kv_len + kt_page_size - 1) // kt_page_size for kv_len in kv_lens)
|
||||
total_kt_tokens = num_gen_tokens * max_kt_tokens
|
||||
output_offsets = (
|
||||
torch.arange(0, num_gen_tokens + 1, device=device, dtype=torch.int32) * max_kt_tokens
|
||||
)
|
||||
|
||||
triton_scores = triton_rocket_paged_kt_cache_bmm(
|
||||
q=q,
|
||||
kt_cache_tensor=kt_cache_tensor,
|
||||
kt_cache_block_offsets=kt_cache_block_offsets,
|
||||
dim_pos=dim_pos,
|
||||
kv_lens=kv_lens_tensor,
|
||||
output_offsets=output_offsets,
|
||||
kt_page_size=kt_page_size,
|
||||
tokens_per_block=tokens_per_block,
|
||||
max_kt_blocks_per_seq=max_kt_blocks_per_seq,
|
||||
total_kt_tokens=total_kt_tokens,
|
||||
sm_scale=None,
|
||||
)
|
||||
|
||||
reference_scores = pytorch_reference_paged_kt_cache_bmm(
|
||||
q=q,
|
||||
k=k,
|
||||
dim_pos=dim_pos,
|
||||
kv_lens=kv_lens_tensor,
|
||||
kt_page_size=kt_page_size,
|
||||
sm_scale=None,
|
||||
)
|
||||
|
||||
# Compare results
|
||||
# Only compare non-zero entries (valid kt tokens)
|
||||
mask = torch.zeros_like(triton_scores, dtype=torch.bool)
|
||||
for batch_idx in range(num_gen_tokens):
|
||||
kv_len = kv_lens[batch_idx]
|
||||
num_kt_tokens = (kv_len + kt_page_size - 1) // kt_page_size
|
||||
offset = batch_idx * max_kt_tokens
|
||||
mask[:, :, offset : offset + num_kt_tokens] = True
|
||||
|
||||
triton_valid = triton_scores[mask]
|
||||
reference_valid = reference_scores[mask]
|
||||
|
||||
max_diff = torch.max(torch.abs(triton_valid - reference_valid)).item()
|
||||
|
||||
print(f"Max absolute difference: {max_diff:.6f}")
|
||||
|
||||
assert max_diff < 0.05, f"Max difference {max_diff} exceeds threshold"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("\n" + "=" * 80)
|
||||
print("Testing Triton BMM Kernel")
|
||||
print("=" * 80)
|
||||
|
||||
# Test triton_bmm
|
||||
print("\n--- Single batch, non-causal ---")
|
||||
test_triton_bmm(1, 32, [128], 8, 8, 128, False)
|
||||
|
||||
print("\n--- Single batch, causal ---")
|
||||
test_triton_bmm(1, 32, [128], 8, 8, 128, True)
|
||||
|
||||
print("\n--- Multiple batches, different k_len ---")
|
||||
test_triton_bmm(3, 32, [64, 128, 256], 8, 8, 128, False)
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("Testing Triton Rocket Paged KT Cache BMM Kernel")
|
||||
print("=" * 80)
|
||||
|
||||
# Test triton_rocket_paged_kt_cache_bmm
|
||||
print("\n--- Single batch ---")
|
||||
test_triton_rocket_paged_kt_cache_bmm(1, [128], 8, 4, 128, 3, 64)
|
||||
|
||||
print("\n--- Multiple batches, different kv_len ---")
|
||||
test_triton_rocket_paged_kt_cache_bmm(3, [100, 200, 150], 8, 4, 128, 3, 64)
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("All tests passed!")
|
||||
print("=" * 80)
|
||||
228
tests/unittest/_torch/attention/sparse/test_triton_topk.py
Normal file
@ -0,0 +1,228 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tensorrt_llm._torch.attention_backend.sparse.kernel import triton_topk
|
||||
|
||||
|
||||
def pytorch_reference_topk(
|
||||
input_tensor: torch.Tensor,
|
||||
kv_offsets: torch.Tensor,
|
||||
kv_lens: torch.Tensor,
|
||||
topk: int,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
input_tensor: Input scores [num_kv_heads, sum(kv_lens)]
|
||||
kv_offsets: KV offsets [batch_size + 1]
|
||||
kv_lens: KV lengths [batch_size]
|
||||
topk: TopK parameter
|
||||
|
||||
Returns:
|
||||
output_indices: Padded indices [num_kv_heads, batch_size, topk]
|
||||
"""
|
||||
num_kv_heads = input_tensor.shape[0]
|
||||
batch_size = len(kv_lens)
|
||||
device = input_tensor.device
|
||||
|
||||
# Compute max sequence length for padding
|
||||
max_seq_len = kv_lens.max().item()
|
||||
|
||||
# Ensure padding size >= topk
|
||||
pad_size = max(max_seq_len, topk)
|
||||
|
||||
# Create padded tensor [num_kv_heads, batch_size, pad_size]
|
||||
padded_tensor = torch.full(
|
||||
(num_kv_heads, batch_size, pad_size), float("-inf"), dtype=input_tensor.dtype, device=device
|
||||
)
|
||||
|
||||
# Fill in actual values
|
||||
for batch_idx in range(batch_size):
|
||||
start = kv_offsets[batch_idx].item()
|
||||
end = kv_offsets[batch_idx + 1].item()
|
||||
seq_len = kv_lens[batch_idx].item()
|
||||
|
||||
for head_idx in range(num_kv_heads):
|
||||
padded_tensor[head_idx, batch_idx, :seq_len] = input_tensor[head_idx, start:end]
|
||||
|
||||
# Perform batch topk: [num_kv_heads, batch_size, pad_size] -> [num_kv_heads, batch_size, topk]
|
||||
topk_values, topk_indices = torch.topk(
|
||||
padded_tensor,
|
||||
k=topk,
|
||||
dim=-1,
|
||||
largest=True,
|
||||
)
|
||||
|
||||
# Mask out invalid indices based on each batch's seq_len
|
||||
seq_lens_expanded = kv_lens.to(device).unsqueeze(0).unsqueeze(-1) # [1, batch_size, 1]
|
||||
# topk_indices: [num_kv_heads, batch_size, topk]
|
||||
mask = topk_indices >= seq_lens_expanded
|
||||
topk_indices.masked_fill_(mask, -1)
|
||||
|
||||
return topk_indices
|
||||
|
||||
|
||||
def triton_topk_wrapper(
|
||||
input_tensor: torch.Tensor,
|
||||
kv_offsets: torch.Tensor,
|
||||
kv_lens: torch.Tensor,
|
||||
topk: int,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
input_tensor: Input scores [num_kv_heads, sum(kv_lens)]
|
||||
kv_offsets: KV offsets [batch_size + 1]
|
||||
kv_lens: KV lengths [batch_size]
|
||||
topk: TopK parameter
|
||||
|
||||
Returns:
|
||||
output_indices: Padded indices [num_kv_heads, batch_size, topk]
|
||||
"""
|
||||
num_kv_heads = input_tensor.shape[0]
|
||||
batch_size = len(kv_lens)
|
||||
device = input_tensor.device
|
||||
|
||||
sparse_lens = torch.tensor(
|
||||
[min(topk, seq_len.item()) for seq_len in kv_lens], dtype=torch.int32, device=device
|
||||
)
|
||||
sparse_offsets = torch.cat(
|
||||
[torch.zeros(1, dtype=torch.int32, device=device), torch.cumsum(sparse_lens, dim=0)]
|
||||
).to(device)
|
||||
total_sparse_indices = sparse_offsets[-1].item()
|
||||
|
||||
output_indices_flat = triton_topk(
|
||||
input_tensor, kv_offsets, sparse_offsets, total_sparse_indices, topk
|
||||
)
|
||||
|
||||
# Convert flat format to padded format [num_kv_heads, batch_size, topk]
|
||||
output_indices_padded = torch.full(
|
||||
(num_kv_heads, batch_size, topk), -1, dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
for batch_idx in range(batch_size):
|
||||
start = sparse_offsets[batch_idx].item()
|
||||
end = sparse_offsets[batch_idx + 1].item()
|
||||
actual_len = end - start
|
||||
|
||||
for head_idx in range(num_kv_heads):
|
||||
output_indices_padded[head_idx, batch_idx, :actual_len] = output_indices_flat[
|
||||
head_idx, start:end
|
||||
]
|
||||
|
||||
return output_indices_padded
|
||||
|
||||
|
||||
def compute_overlap_ratio(
|
||||
triton_indices: torch.Tensor,
|
||||
reference_indices: torch.Tensor,
|
||||
kv_lens: torch.Tensor,
|
||||
) -> float:
|
||||
"""
|
||||
Args:
|
||||
triton_indices: Triton topk results [num_kv_heads, batch_size, topk]
|
||||
reference_indices: Reference topk results [num_kv_heads, batch_size, topk]
|
||||
kv_lens: KV lengths [batch_size]
|
||||
|
||||
Returns:
|
||||
Average overlap ratio across all batches and heads
|
||||
"""
|
||||
num_kv_heads = triton_indices.shape[0]
|
||||
batch_size = triton_indices.shape[1]
|
||||
|
||||
overlap_ratios = []
|
||||
|
||||
# Compare batch by batch
|
||||
for batch_idx in range(batch_size):
|
||||
for head_idx in range(num_kv_heads):
|
||||
# Extract indices for this batch and head
|
||||
triton_batch = triton_indices[head_idx, batch_idx, :].cpu().tolist()
|
||||
reference_batch = reference_indices[head_idx, batch_idx, :].cpu().tolist()
|
||||
|
||||
# Filter out -1 (invalid/padding indices)
|
||||
triton_set = set([x for x in triton_batch if x >= 0])
|
||||
reference_set = set([x for x in reference_batch if x >= 0])
|
||||
|
||||
if len(reference_set) > 0:
|
||||
overlap = len(triton_set & reference_set)
|
||||
overlap_ratio = overlap / len(reference_set)
|
||||
overlap_ratios.append(overlap_ratio)
|
||||
|
||||
if len(overlap_ratios) == 0:
|
||||
return 1.0
|
||||
|
||||
return sum(overlap_ratios) / len(overlap_ratios)
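# Worked example of the metric above (numbers invented for illustration): if the
# reference top-k for one (head, batch) pair is {0, 4, 7, 9} and the Triton kernel
# returns {0, 4, 7, 12}, the overlap is 3 of the 4 reference indices, i.e. 0.75;
# the function averages this ratio over every head and batch entry.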
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"batch_size,seq_lens,num_kv_heads,topk",
|
||||
[
|
||||
# Single batch, seq_len > topk
|
||||
(1, [3000], 1, 2048),
|
||||
# Single batch, seq_len < topk
|
||||
(1, [1000], 8, 2048),
|
||||
# Multiple batches, mixed seq_len (some < topk, some > topk)
|
||||
(6, [50, 150, 80, 300, 100, 256], 8, 128),
|
||||
(6, [1000, 2500, 1500, 3000, 1800, 4000], 8, 2048),
|
||||
],
|
||||
)
|
||||
def test_topk_kernel(batch_size, seq_lens, num_kv_heads, topk):
|
||||
device = torch.device("cuda")
|
||||
dtype = torch.float32
|
||||
|
||||
kv_lens = torch.tensor(seq_lens, dtype=torch.int32, device=device)
|
||||
|
||||
kv_offsets = torch.cat(
|
||||
[torch.zeros(1, dtype=torch.int32, device=device), torch.cumsum(kv_lens, dim=0)]
|
||||
).to(device)
|
||||
|
||||
total_tokens = kv_offsets[-1].item()
|
||||
|
||||
input_tensor = torch.randn((num_kv_heads, total_tokens), dtype=dtype, device=device)
|
||||
|
||||
triton_output = triton_topk_wrapper(
|
||||
input_tensor=input_tensor,
|
||||
kv_offsets=kv_offsets,
|
||||
kv_lens=kv_lens,
|
||||
topk=topk,
|
||||
)
|
||||
|
||||
reference_output = pytorch_reference_topk(
|
||||
input_tensor=input_tensor,
|
||||
kv_offsets=kv_offsets,
|
||||
kv_lens=kv_lens,
|
||||
topk=topk,
|
||||
)
|
||||
|
||||
overlap_ratio = compute_overlap_ratio(
|
||||
triton_output,
|
||||
reference_output,
|
||||
kv_lens,
|
||||
)
|
||||
|
||||
min_threshold = 0.99
|
||||
|
||||
print(f"overlap_ratio: {overlap_ratio}")
|
||||
|
||||
assert overlap_ratio >= min_threshold, (
|
||||
f"Overlap ratio {overlap_ratio:.4f} is too low (< {min_threshold})"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
print("\n" + "=" * 80)
print("Testing Triton TopK Kernel")
print("=" * 80)

# Single batch tests
print("\n--- Single batch, seq_len > topk ---")
test_topk_kernel(1, [3000], 8, 2048)

print("\n--- Single batch, seq_len < topk ---")
test_topk_kernel(1, [1000], 8, 2048)

print("\n--- Multiple batches, mixed seq_len ---")
test_topk_kernel(6, [50, 150, 80, 300, 100, 256], 8, 128)
test_topk_kernel(6, [1000, 2500, 1500, 3000, 1800, 4000], 8, 2048)

print("\n" + "=" * 80)
print("All tests passed!")
print("=" * 80)