[None] [feat] Use triton kernels for RocketKV prediction module (#8682)
Signed-off-by: yuhangh <58161490+heyuhhh@users.noreply.github.com>
Parent: cc4c980e03
Commit: f07e9977c6
@@ -20,7 +20,7 @@ namespace tensorrt_llm
 {
 namespace kernels
 {
-template <int THREADS_PER_BLOCK>
+template <int THREADS_PER_BLOCK, int MAX_NUM_PAGES>
 __global__ void gatherKvPageOffsetsKernel(
     int32_t* output_kv_page_offsets, // [num_head_kv, batch_size, 2, max_num_pages_per_seq]
     int32_t* output_seq_lengths,     // [num_head_kv, batch_size]
@@ -32,23 +32,33 @@ __global__ void gatherKvPageOffsetsKernel(
     // Each CUDA block processes one sequence from the batch for one head.
     int32_t const head_idx = blockIdx.x;
     int32_t const batch_idx = blockIdx.y;
+    int32_t const indices_block_size = sparse_params.sparse_attn_indices_block_size;
     if (batch_idx >= batch_size)
     {
         return;
     }
 
-    // Shared memory for reduction.
-    __shared__ typename cub::BlockReduce<Pair, THREADS_PER_BLOCK>::TempStorage temp_storage;
+    using BlockScan = cub::BlockScan<int32_t, THREADS_PER_BLOCK>;
+    using BlockReduce = cub::BlockReduce<Pair, THREADS_PER_BLOCK>;
+
+    __shared__ typename BlockScan::TempStorage temp_storage_scan;
+    __shared__ typename BlockReduce::TempStorage temp_storage_reduce;
+
+    __shared__ int32_t s_page_mask[MAX_NUM_PAGES];
+    __shared__ int32_t s_cu_page_mask[MAX_NUM_PAGES];
+    __shared__ int32_t s_scan_total; // Store total count from scan
 
     // Get the range of sparse indices and the sequence length.
     int32_t const start_offset = sparse_params.sparse_attn_offsets[batch_idx];
     int32_t const end_offset = sparse_params.sparse_attn_offsets[batch_idx + 1];
-    int32_t const total_pages = sparse_params.sparse_attn_offsets[batch_size];
-    int32_t const num_sparse_pages = end_offset - start_offset;
+    int32_t const sparse_attn_indices_stride = sparse_params.sparse_attn_indices_stride;
+    int32_t const num_sparse_indices = end_offset - start_offset;
     int32_t const original_seq_len = seq_lengths[batch_idx];
+    int32_t const ori_valid_pages = (original_seq_len + tokens_per_page - 1) / tokens_per_page;
+    int32_t const page_loops = (ori_valid_pages + MAX_NUM_PAGES - 1) / MAX_NUM_PAGES;
 
     // Get global sparse index.
-    int32_t const sparse_idx_global = head_idx * total_pages + start_offset;
+    int32_t const sparse_idx_global = head_idx * sparse_attn_indices_stride + start_offset;
 
     // Get the base memory offset. shape: [batch_size, 2, max_num_pages_per_seq]
     size_t const src_base_offset = (size_t) batch_idx * 2 * max_num_pages_per_seq;
@@ -58,56 +68,119 @@ __global__ void gatherKvPageOffsetsKernel(
     int32_t local_max_page_index = -1;
     int32_t local_num_valid_pages = 0;
 
-    // Perform the gather operation.
-    for (int32_t i = threadIdx.x; i < num_sparse_pages; i += blockDim.x)
-    {
-        // Get the source idx and offset.
-        int32_t const src_idx = sparse_params.sparse_attn_indices[sparse_idx_global + i];
-        if (src_idx < 0)
-        {
-            continue;
-        }
-
-        // Update the local max page index.
-        local_max_page_index = max(local_max_page_index, src_idx);
-        local_num_valid_pages++;
-
-        // Get the source and destination offsets.
-        size_t const src_offset_dim0 = src_base_offset + 0 * max_num_pages_per_seq + src_idx;
-        size_t const src_offset_dim1 = src_base_offset + 1 * max_num_pages_per_seq + src_idx;
-        size_t const dst_offset_dim0 = dst_base_offset + 0 * max_num_pages_per_seq + i;
-        size_t const dst_offset_dim1 = dst_base_offset + 1 * max_num_pages_per_seq + i;
-
-        // Perform the gather operation: read from the sparse location and write to the dense location.
-        output_kv_page_offsets[dst_offset_dim0] = kv_page_offsets[src_offset_dim0];
-        output_kv_page_offsets[dst_offset_dim1] = kv_page_offsets[src_offset_dim1];
+    int32_t src_page_idx_offset = 0;
+    int32_t dst_page_idx_offset = 0;
+    for (int32_t loop_idx = 0; loop_idx < page_loops; loop_idx++)
+    {
+        src_page_idx_offset = loop_idx * MAX_NUM_PAGES;
+        int32_t loop_num_valid_pages = min(MAX_NUM_PAGES, ori_valid_pages - src_page_idx_offset);
+        for (int32_t i = threadIdx.x; i < MAX_NUM_PAGES; i += blockDim.x)
+        {
+            s_page_mask[i] = 0;
+        }
+        __syncthreads();
+
+        for (int32_t i = threadIdx.x; i < num_sparse_indices; i += blockDim.x)
+        {
+            int32_t const src_idx = sparse_params.sparse_attn_indices[sparse_idx_global + i];
+            int32_t const src_idx_start = src_idx * indices_block_size;
+            int32_t const src_idx_end = min(src_idx_start + indices_block_size, original_seq_len);
+            for (int32_t j = src_idx_start; j < src_idx_end; j++)
+            {
+                int32_t const src_page_idx = j / tokens_per_page;
+                if (src_page_idx >= src_page_idx_offset && src_page_idx < src_page_idx_offset + loop_num_valid_pages)
+                {
+                    atomicExch(&s_page_mask[src_page_idx - src_page_idx_offset], 1);
+                }
+            }
+        }
+        __syncthreads();
+
+        // Handle case when loop_num_valid_pages > blockDim.x by processing in chunks
+        int32_t scan_offset = 0;
+        int32_t const scan_chunks = (loop_num_valid_pages + blockDim.x - 1) / blockDim.x;
+
+        for (int32_t chunk_idx = 0; chunk_idx < scan_chunks; chunk_idx++)
+        {
+            int32_t const chunk_start = chunk_idx * blockDim.x;
+            int32_t const chunk_size = min((int32_t) blockDim.x, loop_num_valid_pages - chunk_start);
+
+            int32_t thread_data = (threadIdx.x < chunk_size) ? s_page_mask[chunk_start + threadIdx.x] : 0;
+            int32_t thread_output;
+            int32_t aggregate;
+
+            BlockScan(temp_storage_scan).ExclusiveSum(thread_data, thread_output, aggregate);
+            __syncthreads();
+
+            if (threadIdx.x < chunk_size)
+            {
+                s_cu_page_mask[chunk_start + threadIdx.x] = thread_output + scan_offset;
+            }
+            __syncthreads();
+
+            // Update scan offset for next chunk
+            scan_offset += aggregate;
+        }
+
+        if (threadIdx.x == 0)
+        {
+            s_scan_total = scan_offset;
+        }
+        __syncthreads();
+
+        // Perform the gather operation.
+        for (int32_t i = threadIdx.x; i < loop_num_valid_pages; i += blockDim.x)
+        {
+            // Skip if the page is not valid.
+            if (s_page_mask[i] == 0)
+            {
+                continue;
+            }
+
+            int32_t const src_idx = src_page_idx_offset + i;
+            int32_t const dst_idx = dst_page_idx_offset + s_cu_page_mask[i];
+
+            local_max_page_index = max(local_max_page_index, src_idx);
+            local_num_valid_pages++;
+
+            size_t const src_offset_dim0 = src_base_offset + 0 * max_num_pages_per_seq + src_idx;
+            size_t const src_offset_dim1 = src_base_offset + 1 * max_num_pages_per_seq + src_idx;
+            size_t const dst_offset_dim0 = dst_base_offset + 0 * max_num_pages_per_seq + dst_idx;
+            size_t const dst_offset_dim1 = dst_base_offset + 1 * max_num_pages_per_seq + dst_idx;
+
+            output_kv_page_offsets[dst_offset_dim0] = kv_page_offsets[src_offset_dim0];
+            output_kv_page_offsets[dst_offset_dim1] = kv_page_offsets[src_offset_dim1];
+        }
+        __syncthreads();
+
+        // Update dst offset using the total count from scan
+        dst_page_idx_offset += s_scan_total;
     }
 
     // Reduce the local max page indices and number of valid pages.
     Pair local_pair = {local_max_page_index, local_num_valid_pages};
-    Pair result = cub::BlockReduce<Pair, THREADS_PER_BLOCK>(temp_storage).Reduce(local_pair, PairReduceOp());
+    Pair result = BlockReduce(temp_storage_reduce).Reduce(local_pair, PairReduceOp());
 
     // Update sequence length for this head and batch.
     if (threadIdx.x == 0)
     {
         int32_t const max_page_index = result.max_val;
         int32_t const num_valid_pages = result.sum_val;
-        int32_t const ori_valid_pages = (original_seq_len + tokens_per_page - 1) / tokens_per_page;
         size_t const seq_len_offset = (size_t) head_idx * batch_size + batch_idx;
+        int32_t seq_len = 0;
         if (num_valid_pages > 0)
         {
-            int32_t seq_len = original_seq_len - (ori_valid_pages - num_valid_pages) * tokens_per_page;
-            int32_t seq_len_remain = original_seq_len % tokens_per_page;
-            if (max_page_index != ori_valid_pages - 1 && seq_len_remain != 0)
+            if (max_page_index == ori_valid_pages - 1)
             {
-                seq_len += tokens_per_page - seq_len_remain;
+                seq_len = (num_valid_pages - 1) * tokens_per_page
+                    + (original_seq_len - (ori_valid_pages - 1) * tokens_per_page);
             }
+            else
+            {
+                seq_len = num_valid_pages * tokens_per_page;
+            }
-            output_seq_lengths[seq_len_offset] = seq_len;
-        }
-        else
-        {
-            output_seq_lengths[seq_len_offset] = 0;
         }
+        output_seq_lengths[seq_len_offset] = seq_len;
     }
 }
 
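Editor's note, not part of the commit: the loop body above is easier to follow as a small NumPy sketch of the per-(head, batch) compaction it performs — mark every KV page touched by a sparse index block, take an exclusive prefix sum of the mask (the cub::BlockScan step), and use that sum as the dense write position. The helper name `gather_pages` is ours for illustration.

```python
import numpy as np

def gather_pages(sparse_indices, indices_block_size, seq_len, tokens_per_page):
    """Return the unique, ascending KV page ids touched by the sparse index blocks."""
    num_pages = (seq_len + tokens_per_page - 1) // tokens_per_page
    page_mask = np.zeros(num_pages, dtype=np.int32)
    for idx in sparse_indices:
        start = idx * indices_block_size
        end = min(start + indices_block_size, seq_len)
        for tok in range(start, end):
            page_mask[tok // tokens_per_page] = 1      # atomicExch(&s_page_mask[...], 1)
    cu_page_mask = np.cumsum(page_mask) - page_mask    # exclusive prefix sum (BlockScan)
    dst = np.full(num_pages, -1, dtype=np.int32)
    for page in range(num_pages):
        if page_mask[page]:
            dst[cu_page_mask[page]] = page             # compacted write position
    return dst[:page_mask.sum()]

# Token-level indices (indices_block_size = 1), 64-token pages:
print(gather_pages([10, 20, 70, 75], indices_block_size=1, seq_len=146, tokens_per_page=64))
# -> [0 1]
```

The exclusive prefix-sum value of each set mask slot equals the number of selected pages before it, which is why the gathered pages come out dense and in ascending page order.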
@@ -121,11 +194,8 @@ void invokeGatherKvPageOffsets(int32_t* output_kv_page_offsets, int32_t* output_
     dim3 grid(num_head_kv, batch_size, 1);
     // The block.
     dim3 block(256, 1, 1);
-    // Shared memory size.
-    size_t smem_size = sizeof(Pair) * 256;
 
-    // Launch the kernel.
-    gatherKvPageOffsetsKernel<256><<<grid, block, smem_size, stream>>>(output_kv_page_offsets, output_seq_lengths,
+    gatherKvPageOffsetsKernel<256, 512><<<grid, block, 0, stream>>>(output_kv_page_offsets, output_seq_lengths,
         kv_page_offsets, seq_lengths, sparse_params, batch_size, tokens_per_page, max_num_pages_per_seq);
 }
 } // namespace kernels
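For context (our arithmetic, not from the commit): with the new `<256, 512>` instantiation the shared-memory arrays are statically sized (two `int32_t[512]` masks, roughly 4 KB), so the dynamic `smem_size` launch argument is gone. A rough sizing check, assuming 64-token KV pages:

```python
# Assumed values mirroring the launch above; tokens_per_page is runtime-dependent.
THREADS_PER_BLOCK = 256
MAX_NUM_PAGES = 512

def loop_counts(seq_len, tokens_per_page=64):
    ori_valid_pages = (seq_len + tokens_per_page - 1) // tokens_per_page
    page_loops = (ori_valid_pages + MAX_NUM_PAGES - 1) // MAX_NUM_PAGES
    scan_chunks = (min(ori_valid_pages, MAX_NUM_PAGES) + THREADS_PER_BLOCK - 1) // THREADS_PER_BLOCK
    return ori_valid_pages, page_loops, scan_chunks

# A 131072-token sequence spans 2048 pages: 4 outer page loops,
# each scanning its 512 mask slots in 2 chunks of 256 threads.
print(loop_counts(131072))  # (2048, 4, 2)
```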
@@ -35,6 +35,9 @@ struct SparseAttentionParams
     int32_t sparse_mla_topk{0};              // for DSA attention
     void* sparse_mla_kv_cache_pool{nullptr}; // for DSA attention
 
+    int32_t sparse_attn_indices_block_size{1};
+    int32_t sparse_attn_indices_stride{0};
+
     std::string toString() const
     {
         std::stringstream ss;
@@ -43,7 +46,9 @@ struct SparseAttentionParams
            << "sparse_kv_offsets: " << this->sparse_kv_offsets << std::endl
            << "sparse_attn_offsets: " << this->sparse_attn_offsets << std::endl
            << "sparse_mla_topk: " << this->sparse_mla_topk << std::endl
-           << "sparse_mla_kv_cache_pool: " << this->sparse_mla_kv_cache_pool << std::endl;
+           << "sparse_mla_kv_cache_pool: " << this->sparse_mla_kv_cache_pool << std::endl
+           << "sparse_attn_indices_block_size: " << this->sparse_attn_indices_block_size << std::endl
+           << "sparse_attn_indices_stride: " << this->sparse_attn_indices_stride << std::endl;
         return ss.str();
     }
 };
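A note on the two new fields, based on our reading of the kernel code above and shown with hypothetical helpers: `sparse_attn_indices_stride` is the per-head stride into the flattened `[num_kv_heads, stride]` index buffer, and each stored index names a block of `sparse_attn_indices_block_size` consecutive tokens.

```python
def index_slot(head_idx, stride, start_offset, i):
    # Mirrors: sparse_idx_global = head_idx * sparse_attn_indices_stride + start_offset, then + i
    return head_idx * stride + start_offset + i

def tokens_for_index(index, block_size, seq_len):
    # Mirrors: src_idx_start = src_idx * indices_block_size, clipped to original_seq_len
    start = index * block_size
    return list(range(start, min(start + block_size, seq_len)))

print(index_slot(head_idx=2, stride=14, start_offset=8, i=0))  # 36
print(tokens_for_index(3, block_size=16, seq_len=130))         # tokens 48..63
```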
@@ -64,10 +64,11 @@ void initBindings(nb::module_& m)
         nb::arg("softmax_stats_tensor") = std::nullopt, nb::arg("spec_decoding_bool_params"),
         nb::arg("spec_decoding_tensor_params"), nb::arg("sparse_kv_indices") = std::nullopt,
         nb::arg("sparse_kv_offsets") = std::nullopt, nb::arg("sparse_attn_indices") = std::nullopt,
-        nb::arg("sparse_attn_offsets") = std::nullopt, nb::arg("sparse_mla_topk") = std::nullopt,
-        nb::arg("cu_q_seqlens") = std::nullopt, nb::arg("cu_kv_seqlens") = std::nullopt,
-        nb::arg("fmha_scheduler_counter") = std::nullopt, nb::arg("mla_bmm1_scale") = std::nullopt,
-        nb::arg("mla_bmm2_scale") = std::nullopt, nb::arg("quant_q_buffer") = std::nullopt,
-        "Multi-head attention operation", nb::call_guard<nb::gil_scoped_release>());
+        nb::arg("sparse_attn_offsets") = std::nullopt, nb::arg("sparse_attn_indices_block_size"),
+        nb::arg("sparse_mla_topk") = std::nullopt, nb::arg("cu_q_seqlens") = std::nullopt,
+        nb::arg("cu_kv_seqlens") = std::nullopt, nb::arg("fmha_scheduler_counter") = std::nullopt,
+        nb::arg("mla_bmm1_scale") = std::nullopt, nb::arg("mla_bmm2_scale") = std::nullopt,
+        nb::arg("quant_q_buffer") = std::nullopt, "Multi-head attention operation",
+        nb::call_guard<nb::gil_scoped_release>());
 }
 } // namespace tensorrt_llm::nanobind::thop
@@ -64,10 +64,11 @@ void initBindings(pybind11::module_& m)
         py::arg("softmax_stats_tensor") = std::nullopt, py::arg("spec_decoding_bool_params"),
         py::arg("spec_decoding_tensor_params"), py::arg("sparse_kv_indices") = std::nullopt,
         py::arg("sparse_kv_offsets") = std::nullopt, py::arg("sparse_attn_indices") = std::nullopt,
-        py::arg("sparse_attn_offsets") = std::nullopt, py::arg("sparse_mla_topk") = std::nullopt,
-        py::arg("cu_q_seqlens") = std::nullopt, py::arg("cu_kv_seqlens") = std::nullopt,
-        py::arg("fmha_scheduler_counter") = std::nullopt, py::arg("mla_bmm1_scale") = std::nullopt,
-        py::arg("mla_bmm2_scale") = std::nullopt, py::arg("quant_q_buffer") = std::nullopt,
-        "Multi-head attention operation", py::call_guard<py::gil_scoped_release>());
+        py::arg("sparse_attn_offsets") = std::nullopt, py::arg("sparse_attn_indices_block_size"),
+        py::arg("sparse_mla_topk") = std::nullopt, py::arg("cu_q_seqlens") = std::nullopt,
+        py::arg("cu_kv_seqlens") = std::nullopt, py::arg("fmha_scheduler_counter") = std::nullopt,
+        py::arg("mla_bmm1_scale") = std::nullopt, py::arg("mla_bmm2_scale") = std::nullopt,
+        py::arg("quant_q_buffer") = std::nullopt, "Multi-head attention operation",
+        py::call_guard<py::gil_scoped_release>());
 }
 } // namespace tensorrt_llm::pybind::thop
@@ -86,10 +86,11 @@ public:
         c10::ArrayRef<std::optional<torch::Tensor>> spec_decoding_tensor_params,
         torch::optional<torch::Tensor> attention_sinks, torch::optional<torch::Tensor> sparse_kv_indices,
         torch::optional<torch::Tensor> sparse_kv_offsets, torch::optional<torch::Tensor> sparse_attn_indices,
-        torch::optional<torch::Tensor> sparse_attn_offsets, int32_t const sparse_mla_topk,
-        std::optional<torch::Tensor> cu_q_seqlens, std::optional<torch::Tensor> cu_kv_seqlens,
-        std::optional<torch::Tensor> fmha_scheduler_counter, std::optional<torch::Tensor> mla_bmm1_scale,
-        std::optional<torch::Tensor> mla_bmm2_scale, std::optional<torch::Tensor> quant_q_buffer) const
+        torch::optional<torch::Tensor> sparse_attn_offsets, int64_t const sparse_attn_indices_block_size,
+        int32_t const sparse_mla_topk, std::optional<torch::Tensor> cu_q_seqlens,
+        std::optional<torch::Tensor> cu_kv_seqlens, std::optional<torch::Tensor> fmha_scheduler_counter,
+        std::optional<torch::Tensor> mla_bmm1_scale, std::optional<torch::Tensor> mla_bmm2_scale,
+        std::optional<torch::Tensor> quant_q_buffer) const
         = 0;
 };
 
@@ -146,10 +147,11 @@ public:
         c10::ArrayRef<std::optional<torch::Tensor>> spec_decoding_tensor_params,
         torch::optional<torch::Tensor> attention_sinks, torch::optional<torch::Tensor> sparse_kv_indices,
         torch::optional<torch::Tensor> sparse_kv_offsets, torch::optional<torch::Tensor> sparse_attn_indices,
-        torch::optional<torch::Tensor> sparse_attn_offsets, int32_t const sparse_mla_topk,
-        std::optional<torch::Tensor> cu_q_seqlens, std::optional<torch::Tensor> cu_kv_seqlens,
-        std::optional<torch::Tensor> fmha_scheduler_counter, std::optional<torch::Tensor> mla_bmm1_scale,
-        std::optional<torch::Tensor> mla_bmm2_scale, std::optional<torch::Tensor> quant_q_buffer) const override
+        torch::optional<torch::Tensor> sparse_attn_offsets, int64_t const sparse_attn_indices_block_size,
+        int32_t const sparse_mla_topk, std::optional<torch::Tensor> cu_q_seqlens,
+        std::optional<torch::Tensor> cu_kv_seqlens, std::optional<torch::Tensor> fmha_scheduler_counter,
+        std::optional<torch::Tensor> mla_bmm1_scale, std::optional<torch::Tensor> mla_bmm2_scale,
+        std::optional<torch::Tensor> quant_q_buffer) const override
     {
         auto stream = at::cuda::getCurrentCUDAStream(qkv_or_q.get_device());
         T* attention_input = static_cast<T*>(qkv_or_q.slice(0, token_offset).data_ptr());
@@ -395,6 +397,9 @@ public:
             = sparse_attn_indices.has_value() ? sparse_attn_indices.value().data_ptr<int32_t>() : nullptr;
         op.mRuntimeSparseAttentionParams.sparse_attn_offsets
             = sparse_attn_offsets.has_value() ? sparse_attn_offsets.value().data_ptr<int32_t>() : nullptr;
+        op.mRuntimeSparseAttentionParams.sparse_attn_indices_block_size = sparse_attn_indices_block_size;
+        op.mRuntimeSparseAttentionParams.sparse_attn_indices_stride
+            = sparse_attn_indices.has_value() ? sparse_attn_indices.value().size(-1) : 0;
         if (op.isMLAEnabled() && op.mUseSparseAttention)
         {
             op.mRuntimeSparseAttentionParams.sparse_mla_topk = sparse_mla_topk;
@@ -589,10 +594,10 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
     std::vector<std::optional<torch::Tensor>> spec_decoding_tensor_params,
     std::optional<torch::Tensor> sparse_kv_indices, std::optional<torch::Tensor> sparse_kv_offsets,
     std::optional<torch::Tensor> sparse_attn_indices, std::optional<torch::Tensor> sparse_attn_offsets,
-    std::optional<int64_t> sparse_mla_topk, std::optional<torch::Tensor> cu_q_seqlens,
-    std::optional<torch::Tensor> cu_kv_seqlens, std::optional<torch::Tensor> fmha_scheduler_counter,
-    std::optional<torch::Tensor> mla_bmm1_scale, std::optional<torch::Tensor> mla_bmm2_scale,
-    std::optional<torch::Tensor> quant_q_buffer)
+    int64_t const sparse_attn_indices_block_size, std::optional<int64_t> sparse_mla_topk,
+    std::optional<torch::Tensor> cu_q_seqlens, std::optional<torch::Tensor> cu_kv_seqlens,
+    std::optional<torch::Tensor> fmha_scheduler_counter, std::optional<torch::Tensor> mla_bmm1_scale,
+    std::optional<torch::Tensor> mla_bmm2_scale, std::optional<torch::Tensor> quant_q_buffer)
 {
     TLLM_LOG_TRACE("Attention op starts at layer %d", layer_idx);
     // Use these tensors to infer if the attention is using KV cache
@@ -847,8 +852,8 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
             rotary_inv_freq, rotary_cos_sin, latent_cache, q_pe, block_ids_per_seq, mrope_rotary_cos_sin,
             mrope_position_deltas, mla_tensor_params, softmax_stats_tensor, spec_decoding_tensor_params,
             attention_sinks, sparse_kv_indices, sparse_kv_offsets, sparse_attn_indices, sparse_attn_offsets,
-            sparse_mla_topk_value, cu_q_seqlens, cu_kv_seqlens, fmha_scheduler_counter, mla_bmm1_scale, mla_bmm2_scale,
-            quant_q_buffer);
+            sparse_attn_indices_block_size, sparse_mla_topk_value, cu_q_seqlens, cu_kv_seqlens, fmha_scheduler_counter,
+            mla_bmm1_scale, mla_bmm2_scale, quant_q_buffer);
     }
 
     if ((num_generations > 0) && (attn_input_type != AttentionInputType::ContextOnly))
@@ -866,8 +871,8 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
             rotary_inv_freq, rotary_cos_sin, latent_cache, q_pe, block_ids_per_seq, mrope_rotary_cos_sin,
             mrope_position_deltas, mla_tensor_params, softmax_stats_tensor, spec_decoding_tensor_params,
             attention_sinks, sparse_kv_indices, sparse_kv_offsets, sparse_attn_indices, sparse_attn_offsets,
-            sparse_mla_topk_value, cu_q_seqlens, cu_kv_seqlens, fmha_scheduler_counter, mla_bmm1_scale, mla_bmm2_scale,
-            quant_q_buffer);
+            sparse_attn_indices_block_size, sparse_mla_topk_value, cu_q_seqlens, cu_kv_seqlens, fmha_scheduler_counter,
+            mla_bmm1_scale, mla_bmm2_scale, quant_q_buffer);
     }
 
     TLLM_LOG_TRACE("Attention op stops at layer %d", layer_idx);
@@ -63,9 +63,9 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
     std::vector<std::optional<torch::Tensor>> spec_decoding_tensor_params,
     std::optional<torch::Tensor> sparse_kv_indices, std::optional<torch::Tensor> sparse_kv_offsets,
     std::optional<torch::Tensor> sparse_attn_indices, std::optional<torch::Tensor> sparse_attn_offsets,
-    std::optional<int64_t> sparse_mla_topk, std::optional<torch::Tensor> cu_q_seqlens,
-    std::optional<torch::Tensor> cu_kv_seqlens, std::optional<torch::Tensor> fmha_scheduler_counter,
-    std::optional<torch::Tensor> mla_bmm1_scale, std::optional<torch::Tensor> mla_bmm2_scale,
-    std::optional<torch::Tensor> quant_q_buffer);
+    int64_t const sparse_attn_indices_block_size, std::optional<int64_t> sparse_mla_topk,
+    std::optional<torch::Tensor> cu_q_seqlens, std::optional<torch::Tensor> cu_kv_seqlens,
+    std::optional<torch::Tensor> fmha_scheduler_counter, std::optional<torch::Tensor> mla_bmm1_scale,
+    std::optional<torch::Tensor> mla_bmm2_scale, std::optional<torch::Tensor> quant_q_buffer);
 
 } // namespace torch_ext
@@ -36,14 +36,16 @@ TEST_F(sparseAttentionKernelsTest, GatherKvPageOffsetsKernelTest)
     constexpr int num_head_kv = 4;
     constexpr int max_num_pages_per_seq = 8;
     constexpr int tokens_per_page = 64;
-    constexpr int total_sparse_pages = max_batch_size * max_num_pages_per_seq; // Total sparse pages across all batches
+    // Batch 0 has 8 sparse tokens, Batch 1 has 6 sparse tokens, total = 14
+    constexpr int total_sparse_tokens = 14;
 
     // Create input buffers
     auto kv_page_offsets
         = mBufferManager->gpu(ITensor::makeShape({batch_size, 2, max_num_pages_per_seq}), nvinfer1::DataType::kINT32);
     auto seq_lengths = mBufferManager->gpu(ITensor::makeShape({batch_size}), nvinfer1::DataType::kINT32);
+    // Shape: [num_head_kv, total_sparse_tokens] - flattened across all batches
     auto sparse_indices
-        = mBufferManager->gpu(ITensor::makeShape({total_sparse_pages, num_head_kv}), nvinfer1::DataType::kINT32);
+        = mBufferManager->gpu(ITensor::makeShape({num_head_kv, total_sparse_tokens}), nvinfer1::DataType::kINT32);
     auto sparse_indices_offsets = mBufferManager->gpu(ITensor::makeShape({batch_size + 1}), nvinfer1::DataType::kINT32);
 
     // Create output buffers
@@ -57,7 +59,7 @@ TEST_F(sparseAttentionKernelsTest, GatherKvPageOffsetsKernelTest)
         ITensor::makeShape({batch_size, 2, max_num_pages_per_seq}), nvinfer1::DataType::kINT32);
     auto seq_lengths_host = mBufferManager->pinned(ITensor::makeShape({batch_size}), nvinfer1::DataType::kINT32);
     auto sparse_indices_host
-        = mBufferManager->pinned(ITensor::makeShape({total_sparse_pages, num_head_kv}), nvinfer1::DataType::kINT32);
+        = mBufferManager->pinned(ITensor::makeShape({num_head_kv, total_sparse_tokens}), nvinfer1::DataType::kINT32);
     auto sparse_indices_offsets_host
         = mBufferManager->pinned(ITensor::makeShape({batch_size + 1}), nvinfer1::DataType::kINT32);
 
@@ -81,27 +83,43 @@ TEST_F(sparseAttentionKernelsTest, GatherKvPageOffsetsKernelTest)
     }
 
     // Initialize sequence lengths
-    seq_lengths_ptr[0] = 2 * tokens_per_page + 18; // 3 pages for batch 0
-    seq_lengths_ptr[1] = 3 * tokens_per_page + 3;  // 4 pages for batch 1
+    seq_lengths_ptr[0] = 2 * tokens_per_page + 18; // 3 pages (146 tokens) for batch 0
+    seq_lengths_ptr[1] = 3 * tokens_per_page + 3;  // 4 pages (195 tokens) for batch 1
 
-    // Initialize sparse indices with different patterns for different heads
-    // Shape: {total_sparse_pages, num_head_kv}
-    // Each head can have its own sparse pattern
-    int num_sparse_pages = 5;
-    int sparse_page_indices[5][4] = {{1, 0, 0, 1}, {2, 1, 1, -1}, {-1, 2, -1, -1}, {0, 1, 2, 3}, {3, 3, 3, -1}};
-    for (int page = 0; page < num_sparse_pages; ++page)
+    // Initialize sparse indices with token-level indices (indices_block_size = 1)
+    // Shape: [num_head_kv, total_sparse_tokens]
+    // All heads have the same number of sparse tokens: 8 for batch 0, 6 for batch 1
+    // Memory layout: sparse_indices_ptr[head_idx * total_sparse_tokens + token_offset]
+    std::vector<std::vector<int>> sparse_tokens_per_head
+        = {// Head 0: Batch 0 [10,20,70,75,90,95,100,105] -> pages [0,0,1,1,1,1,1,1] -> unique [0,1]
+           //         Batch 1 [64,65,128,129,192,193] -> pages [1,1,2,2,3,3] -> unique [1,2,3]
+           {10, 20, 70, 75, 90, 95, 100, 105, 64, 65, 128, 129, 192, 193},
+
+           // Head 1: Batch 0 [5,6,65,66,130,131,135,140] -> pages [0,0,1,1,2,2,2,2] -> unique [0,1,2]
+           //         Batch 1 [70,71,128,129,190,191] -> pages [1,1,2,2,2,2] -> unique [1,2]
+           {5, 6, 65, 66, 130, 131, 135, 140, 70, 71, 128, 129, 190, 191},
+
+           // Head 2: Batch 0 [20,21,80,81,85,86,90,91] -> pages [0,0,1,1,1,1,1,1] -> unique [0,1]
+           //         Batch 1 [64,65,66,67,68,69] -> pages [1,1,1,1,1,1] -> unique [1]
+           {20, 21, 80, 81, 85, 86, 90, 91, 64, 65, 66, 67, 68, 69},
+
+           // Head 3: Batch 0 [70,71,72,73,74,75,76,77] -> pages [1,1,1,1,1,1,1,1] -> unique [1]
+           //         Batch 1 [192,193,194,195,196,197] -> pages [3,3,3,3,3,3] -> unique [3]
+           {70, 71, 72, 73, 74, 75, 76, 77, 192, 193, 194, 195, 196, 197}};
+
+    // Fill sparse_indices_ptr using the defined data
+    for (int head = 0; head < num_head_kv; ++head)
     {
-        for (int head = 0; head < num_head_kv; ++head)
+        for (int token_idx = 0; token_idx < total_sparse_tokens; ++token_idx)
         {
-            int idx = head * num_sparse_pages + page;
-            sparse_indices_ptr[idx] = sparse_page_indices[page][head];
+            sparse_indices_ptr[head * total_sparse_tokens + token_idx] = sparse_tokens_per_head[head][token_idx];
         }
     }
 
-    // Initialize sparse indices offsets
+    // Initialize sparse indices offsets (these are per-batch offsets into the flattened array)
     sparse_indices_offsets_ptr[0] = 0;  // Start of batch 0
-    sparse_indices_offsets_ptr[1] = 3;  // Start of batch 1 (3 sparse pages for batch 0)
-    sparse_indices_offsets_ptr[2] = 5;  // End (3 sparse pages for batch 1)
+    sparse_indices_offsets_ptr[1] = 8;  // Start of batch 1 (batch 0 has 8 sparse tokens)
+    sparse_indices_offsets_ptr[2] = 14; // End (batch 1 has 6 sparse tokens, total = 14)
 
     // Copy data to GPU
     mBufferManager->copy(*kv_page_offsets_host, *kv_page_offsets);
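The page annotations in the test data above can be checked with a few lines of plain Python (ours, not part of the test); each token index maps to page `token // tokens_per_page` with `tokens_per_page = 64`:

```python
def pages(tokens, tokens_per_page=64):
    seen = []
    for t in tokens:
        p = t // tokens_per_page
        if p not in seen:
            seen.append(p)  # keep first-seen order, drop duplicates
    return seen

print(pages([10, 20, 70, 75, 90, 95, 100, 105]))  # head 0, batch 0 -> [0, 1]
print(pages([64, 65, 128, 129, 192, 193]))        # head 0, batch 1 -> [1, 2, 3]
print(pages([70, 71, 72, 73, 74, 75, 76, 77]))    # head 3, batch 0 -> [1]
```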
@@ -112,6 +130,8 @@ TEST_F(sparseAttentionKernelsTest, GatherKvPageOffsetsKernelTest)
     SparseAttentionParams sparse_params;
     sparse_params.sparse_attn_indices = bufferCast<int32_t>(*sparse_indices);
     sparse_params.sparse_attn_offsets = bufferCast<int32_t>(*sparse_indices_offsets);
+    sparse_params.sparse_attn_indices_block_size = 1; // Token-level indexing
+    sparse_params.sparse_attn_indices_stride = total_sparse_tokens;
 
     // Launch the kernel
     invokeGatherKvPageOffsets(bufferCast<int32_t>(*output_kv_page_offsets), bufferCast<int32_t>(*output_seq_lengths),
@@ -136,21 +156,49 @@ TEST_F(sparseAttentionKernelsTest, GatherKvPageOffsetsKernelTest)
     auto output_kv_offsets_ptr = bufferCast<int32_t>(*output_kv_page_offsets_host);
     auto output_seq_len_ptr = bufferCast<int>(*output_seq_lengths_host);
 
-    // Verify sequence lengths for each head and batch
-    int expected_seq_lens[4][2] = {
-        {tokens_per_page + 18, tokens_per_page + 3},     // Head 0: batch 0 has 2 pages, batch 1 has 0 pages
-        {2 * tokens_per_page + 18, tokens_per_page + 3}, // Head 1: batch 0 has 3 pages, batch 1 has 0 pages
-        {2 * tokens_per_page, tokens_per_page + 3},      // Head 2: batch 0 has 2 pages, batch 1 has 0 pages
-        {tokens_per_page, 3}                             // Head 3: batch 0 has 2 pages, batch 1 has 0 pages
+    // Define expected results for each head and batch
+    // Format: {num_pages, {page_indices...}, seq_len}
+    struct ExpectedResult
+    {
+        int num_pages;
+        std::vector<int> page_indices;
+        int seq_len;
     };
 
+    ExpectedResult expected_results[4][2] = {// Head 0
+        {// Batch 0: tokens on pages [0,1] -> 2 pages, seq_len = 2 * 64 = 128
+         {2, {0, 1}, 2 * tokens_per_page},
+         // Batch 1: tokens on pages [1,2,3] -> 3 pages, max_page=3 (last page)
+         // seq_len = 195 - (4-3)*64 = 131 (no padding needed, max_page is last page)
+         {3, {1, 2, 3}, 131}},
+        // Head 1
+        {// Batch 0: tokens on pages [0,1,2] -> 3 pages (all), seq_len = 146
+         {3, {0, 1, 2}, 2 * tokens_per_page + 18},
+         // Batch 1: tokens on pages [1,2] -> 2 pages, max_page=2 (not last page)
+         // seq_len = 195 - (4-2)*64 = 67, padding: 67 + (64-3) = 128
+         {2, {1, 2}, 2 * tokens_per_page}},
+        // Head 2
+        {// Batch 0: tokens on pages [0,1] -> 2 pages, seq_len = 128
+         {2, {0, 1}, 2 * tokens_per_page},
+         // Batch 1: tokens on page [1] -> 1 page, max_page=1 (not last page)
+         // seq_len = 195 - (4-1)*64 = 3, padding: 3 + (64-3) = 64
+         {1, {1}, tokens_per_page}},
+        // Head 3
+        {// Batch 0: tokens on page [1] -> 1 page, seq_len = 64
+         {1, {1}, tokens_per_page},
+         // Batch 1: tokens on page [3] -> 1 page, max_page=3 (last page)
+         // seq_len = 195 - (4-1)*64 = 3 (no padding needed, max_page is last page)
+         {1, {3}, 3}}};
+
+    // Verify sequence lengths for each head and batch
     for (int h = 0; h < num_head_kv; ++h)
     {
         for (int b = 0; b < batch_size; ++b)
        {
             int seq_len_idx = h * batch_size + b;
-            EXPECT_EQ(output_seq_len_ptr[seq_len_idx], expected_seq_lens[h][b])
-                << "Sequence length mismatch at head=" << h << ", batch=" << b;
+            EXPECT_EQ(output_seq_len_ptr[seq_len_idx], expected_results[h][b].seq_len)
+                << "Sequence length mismatch at head=" << h << ", batch=" << b
+                << ", expected=" << expected_results[h][b].seq_len << ", got=" << output_seq_len_ptr[seq_len_idx];
         }
     }
 
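The expected sequence lengths follow the kernel's rule: if the last (partial) page of the original sequence is among the selected pages, its partial length is kept; otherwise every selected page counts as a full page. A small reference computation (ours, plain Python) reproduces the values used above:

```python
def expected_seq_len(original_seq_len, selected_pages, tokens_per_page=64):
    ori_valid_pages = (original_seq_len + tokens_per_page - 1) // tokens_per_page
    num_valid = len(selected_pages)
    if num_valid == 0:
        return 0
    if max(selected_pages) == ori_valid_pages - 1:  # last, possibly partial, page selected
        return (num_valid - 1) * tokens_per_page + (
            original_seq_len - (ori_valid_pages - 1) * tokens_per_page)
    return num_valid * tokens_per_page              # only full pages selected

print(expected_seq_len(146, [0, 1]))     # batch 0 with 2 pages -> 128
print(expected_seq_len(195, [1, 2, 3]))  # head 0, batch 1      -> 131
print(expected_seq_len(195, [1, 2]))     # head 1, batch 1      -> 128
print(expected_seq_len(195, [3]))        # head 3, batch 1      -> 3
```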
@@ -159,20 +207,13 @@ TEST_F(sparseAttentionKernelsTest, GatherKvPageOffsetsKernelTest)
     {
         for (int b = 0; b < batch_size; ++b)
         {
-            int num_sparse_pages_batch = sparse_indices_offsets_ptr[b + 1] - sparse_indices_offsets_ptr[b];
+            auto const& expected = expected_results[h][b];
 
             for (int d = 0; d < 2; ++d)
             {
-                for (int p = 0; p < num_sparse_pages_batch; ++p)
+                for (int p = 0; p < expected.num_pages; ++p)
                 {
-                    // Calculate expected value (from the sparse index)
-                    int sparse_idx_global = sparse_indices_offsets_ptr[b] + p;
-                    int src_page_idx
-                        = sparse_indices_ptr[h * sparse_indices_offsets_ptr[batch_size] + sparse_idx_global];
-
-                    if (src_page_idx == -1)
-                    {
-                        continue; // Skip invalid indices
-                    }
+                    int src_page_idx = expected.page_indices[p];
 
                     // Calculate output offset
                     size_t output_offset = h * batch_size * 2 * max_num_pages_per_seq + b * 2 * max_num_pages_per_seq
@@ -181,7 +222,9 @@ TEST_F(sparseAttentionKernelsTest, GatherKvPageOffsetsKernelTest)
                     int expected_value = 1000 + b * 100 + d * 10 + src_page_idx;
 
                     EXPECT_EQ(output_kv_offsets_ptr[output_offset], expected_value)
-                        << "Mismatch at head=" << h << ", batch=" << b << ", dim=" << d << ", page=" << p;
+                        << "KV page offset mismatch at head=" << h << ", batch=" << b << ", dim=" << d << ", page=" << p
+                        << ", expected_page_idx=" << src_page_idx << ", expected_value=" << expected_value
+                        << ", got=" << output_kv_offsets_ptr[output_offset];
                 }
             }
         }
@@ -6,10 +6,11 @@ This example demonstrates how to use sparse attention with TensorRT-LLM.
 
 Supported sparse attention algorithms:
 - RocketKV
+- DSA
 
 Usage:
 ```bash
-python llm_sparse_attention.py --algo RocketKV --attention_backend TRTLLM --window_size 32 --kernel_size 63 --prompt_budget 2048
+python llm_sparse_attention.py --algo ROCKETKV --attention_backend TRTLLM --window_size 32 --kernel_size 63 --prompt_budget 2048
 ```
 """
 import argparse
@@ -70,7 +71,7 @@ def parse_arguments():
                         help="The maximum chunk size for the indexer.")
     parser.add_argument("--max_seq_len",
                         type=int,
-                        default=8192,
+                        default=10240,
                         help="The maximum sequence length.")
     parser.add_argument("--max_batch_size",
                         type=int,
@@ -83,7 +84,7 @@ def parse_arguments():
     parser.add_argument(
         "--max_num_tokens",
         type=int,
-        default=8192,
+        default=81920,
         help=
        "The maximum total tokens (context + generation) across all sequences in a batch."
     )
@@ -104,7 +105,7 @@ def parse_arguments():
 
     # KV cache
     parser.add_argument('--kv_cache_dtype', type=str, default='auto')
-    parser.add_argument("--kv_cache_fraction", type=float, default=None)
+    parser.add_argument("--kv_cache_fraction", type=float, default=0.7)
     parser.add_argument('--num_samples', type=int, default=10)
 
     # Runtime
@@ -356,7 +356,6 @@ def initialize_llm(args: argparse.Namespace) -> Tuple[LLM, AutoTokenizer]:
         max_seq_len=args.max_seq_len,
         max_num_tokens=args.max_num_tokens,
         cuda_graph_config=cuda_graph_config,
-        torch_compile_config=None,
         print_iter_log=args.print_iter_log,
         moe_config=MoeConfig(backend=args.moe_backend),
     )
@@ -28,7 +28,8 @@ from transformers import AutoTokenizer
 
 # Add tensorrt_llm imports
 from tensorrt_llm import LLM, SamplingParams
-from tensorrt_llm.llmapi import KvCacheConfig, RocketSparseAttentionConfig
+from tensorrt_llm.llmapi import (CudaGraphConfig, KvCacheConfig,
+                                 RocketSparseAttentionConfig)
 from tensorrt_llm.logger import logger
 
 # Chat templates mapping
@@ -362,6 +363,10 @@ def initialize_llm(args: argparse.Namespace) -> Tuple[LLM, AutoTokenizer]:
         sparse_attention_config = None
         logger.info("Using standard attention")
 
+    cuda_graph_config = CudaGraphConfig(
+        max_batch_size=args.max_batch_size
+    ) if args.attention_backend == "TRTLLM" else None
+
     # Initialize LLM
     llm = LLM(
         model=args.model_path,
@@ -372,8 +377,7 @@ def initialize_llm(args: argparse.Namespace) -> Tuple[LLM, AutoTokenizer]:
         tensor_parallel_size=args.tensor_parallel_size,
         max_seq_len=args.max_seq_len,
         max_num_tokens=args.max_num_tokens,
-        cuda_graph_config=None,
-        torch_compile_config=None,
+        cuda_graph_config=cuda_graph_config,
     )
 
     # Initialize tokenizer
(The diff of one file in this commit is suppressed because it is too large.)
@@ -23,7 +23,13 @@ from tensorrt_llm.bindings.internal.batch_manager import \
 from tensorrt_llm.mapping import Mapping
 from tensorrt_llm.models.modeling_utils import QuantConfig
 
-from .kernel import triton_index_gather, triton_update_kt_cache
+from .kernel import (triton_bmm, triton_flatten_to_batch, triton_index_gather,
+                     triton_rocket_batch_to_flatten,
+                     triton_rocket_paged_kt_cache_bmm, triton_rocket_qk_split,
+                     triton_rocket_reduce_scores,
+                     triton_rocket_update_kt_cache_ctx,
+                     triton_rocket_update_kt_cache_gen, triton_softmax,
+                     triton_topk)
 
 ModelConfig = tensorrt_llm.bindings.ModelConfig
 
@@ -35,8 +41,77 @@ class RocketTrtllmAttentionMetadata(TrtllmAttentionMetadata):
         if self.sparse_attention_config is None:
             raise ValueError("Sparse attention config is not set")
         self.prompt_budget = self.sparse_attention_config.prompt_budget
+        self.window_size = self.sparse_attention_config.window_size
+        self.page_size = self.sparse_attention_config.page_size
+        self.topk = self.sparse_attention_config.topk
+
         capture_graph = torch.cuda.is_current_stream_capturing()
 
+        # Cumulative valid sequence lengths for query and key
+        self.q_cu_seqlens_cuda = self.get_empty(
+            self.cuda_graph_buffers,
+            (self.max_num_sequences + 1, ),
+            dtype=torch.int32,
+            cache_name="q_cu_seqlens_cuda",
+            capture_graph=capture_graph,
+        )
+
+        self.q_cu_seqlens = torch.zeros_like(self.q_cu_seqlens_cuda,
+                                             device='cpu',
+                                             dtype=torch.int32)
+
+        self.k_cu_seqlens_cuda = self.get_empty(
+            self.cuda_graph_buffers,
+            (self.max_num_sequences + 1, ),
+            dtype=torch.int32,
+            cache_name="k_cu_seqlens_cuda",
+            capture_graph=capture_graph,
+        )
+        self.k_cu_seqlens = torch.zeros_like(self.k_cu_seqlens_cuda,
+                                             device='cpu',
+                                             dtype=torch.int32)
+
+        # Context length of RocketKV key for each valid sequence
+        self.k_context_lens = torch.empty(
+            self.max_num_sequences,
+            device='cpu',
+            dtype=torch.int32,
+        )
+
+        # Cumulative context lengths for each sequence
+        self.context_cumsum_cuda = self.get_empty(
+            self.cuda_graph_buffers,
+            (self.max_num_sequences + 1, ),
+            dtype=torch.int32,
+            cache_name="context_cumsum_cuda",
+            capture_graph=capture_graph,
+        )
+        self.context_cumsum = torch.zeros_like(self.context_cumsum_cuda,
+                                               device='cpu',
+                                               dtype=torch.int32)
+
+        # Sparse kv indices offsets for each sequence in context phase
+        self.sparse_offsets_ctx_cuda = self.get_empty(
+            self.cuda_graph_buffers,
+            (self.max_num_sequences + 1, ),
+            dtype=torch.int32,
+            cache_name="sparse_offsets_ctx_cuda",
+            capture_graph=capture_graph,
+        )
+        self.sparse_offsets_ctx = torch.zeros_like(self.sparse_offsets_ctx_cuda,
+                                                   device='cpu',
+                                                   dtype=torch.int32)
+
+        # Valid sequence indices used in sparse kv indices prediction
+        self.valid_seq_indices_cuda = self.get_empty(
+            self.cuda_graph_buffers,
+            (self.max_num_sequences, ),
+            dtype=torch.int32,
+            cache_name="valid_seq_indices_cuda",
+            capture_graph=capture_graph,
+        )
+
+        # KT cache block offsets used in KT cache related kernels
         self.kt_cache_block_offsets = self.get_empty(
             self.cuda_graph_buffers,
             [
@@ -54,6 +129,41 @@ class RocketTrtllmAttentionMetadata(TrtllmAttentionMetadata):
             pin_memory=True,
         )
 
+        # Number of KT tokens for each sequence
+        self.num_kt_tokens = torch.empty(
+            self.max_num_sequences,
+            device='cpu',
+            dtype=torch.int32,
+        )
+
+        # Cumulative KT lengths for each sequence
+        self.cum_kt_lens_cuda = self.get_empty(
+            self.cuda_graph_buffers,
+            (self.max_num_sequences + 1, ),
+            dtype=torch.int32,
+            cache_name="cum_kt_lens_cuda",
+            capture_graph=capture_graph,
+        )
+        self.cum_kt_lens = torch.zeros_like(self.cum_kt_lens_cuda,
+                                            device='cpu',
+                                            dtype=torch.int32)
+
+        # Sparse attn indices offsets for each sequence in generation phase
+        self.sparse_offsets_gen_cuda = self.get_empty(
+            self.cuda_graph_buffers,
+            (self.max_num_sequences + 1, ),
+            dtype=torch.int32,
+            cache_name="sparse_offsets_gen_cuda",
+            capture_graph=capture_graph,
+        )
+        self.sparse_offsets_gen = torch.zeros_like(self.sparse_offsets_gen_cuda,
+                                                   device='cpu',
+                                                   dtype=torch.int32)
+
+        # Maximum number of KT tokens
+        self.max_kt_tokens = (self.max_seq_len + self.page_size -
+                              1) // self.page_size
+
     @property
     def kt_tokens_per_block(self) -> Optional[int]:
         """
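The KT ("key token") bookkeeping introduced here is plain ceil-division: one KT token summarizes one KV page. A hedged sketch (helper name ours) of how `num_kt_tokens`, `cum_kt_lens`, and `max_kt_tokens` relate, assuming `kv_lens` holds the cached KV length of each generation request:

```python
import torch

def kt_token_bookkeeping(kv_lens, page_size, max_seq_len):
    kv_lens = torch.as_tensor(kv_lens, dtype=torch.int32)
    num_kt_tokens = (kv_lens + page_size - 1) // page_size       # KT tokens per sequence
    cum_kt_lens = torch.zeros(len(kv_lens) + 1, dtype=torch.int32)
    cum_kt_lens[1:] = torch.cumsum(num_kt_tokens, dim=0)         # prefix sums consumed by kernels
    max_kt_tokens = (max_seq_len + page_size - 1) // page_size   # upper bound used for padding
    return num_kt_tokens, cum_kt_lens, max_kt_tokens

print(kt_token_bookkeeping([130, 4096], page_size=32, max_seq_len=8192))
# (tensor([  5, 128], dtype=torch.int32), tensor([  0,   5, 133], dtype=torch.int32), 256)
```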
@ -100,81 +210,85 @@ class RocketTrtllmAttentionMetadata(TrtllmAttentionMetadata):
|
|||||||
self.host_kt_cache_block_offsets[:self.num_seqs],
|
self.host_kt_cache_block_offsets[:self.num_seqs],
|
||||||
non_blocking=True)
|
non_blocking=True)
|
||||||
|
|
||||||
|
# -------------------------------- Context phase --------------------------------
|
||||||
|
self.context_cumsum[1:self.num_contexts + 1] = torch.cumsum(
|
||||||
|
self.prompt_lens_cpu[:self.num_contexts], dim=0)
|
||||||
|
self.context_cumsum_cuda[:self.num_contexts + 1].copy_(
|
||||||
|
self.context_cumsum[:self.num_contexts + 1], non_blocking=True)
|
||||||
|
|
||||||
@torch.compile(dynamic=True)
|
# We need to filter out sequences that are too short to skip sparse kv indices prediction
|
||||||
def convert_token_to_page_sparse_indices(
|
valid_mask = self.prompt_lens_cpu[:self.
|
||||||
sparse_attn_indices: torch.Tensor, sparse_attn_offsets: torch.Tensor,
|
num_contexts] >= self.prompt_budget
|
||||||
metadata: 'TrtllmAttentionMetadata'
|
valid_seq_indices = torch.where(valid_mask)[0]
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
invalid_seq_indices = torch.where(~valid_mask)[0]
|
||||||
"""
|
valid_batch_size = len(valid_seq_indices)
|
||||||
Convert token-based sparse attention indices to page-based sparse attention indices.
|
self.valid_seq_indices_cuda[:valid_batch_size].copy_(valid_seq_indices,
|
||||||
|
non_blocking=True)
|
||||||
|
|
||||||
Args:
|
# Only consider sequences that are long enough for sparse kv indices prediction in context phase
|
||||||
sparse_attn_indices: Token-based indices with shape [num_tokens, num_kv_heads]
|
self.k_context_lens[:valid_batch_size] = self.prompt_lens_cpu[
|
||||||
sparse_attn_offsets: Offsets with shape [batch_size+1] indicating token boundaries for each batch
|
valid_seq_indices] - self.window_size
|
||||||
metadata: Attention metadata containing tokens_per_block (page_size)
|
if valid_batch_size > 0:
|
||||||
|
# Maximum context length of RocketKV key for valid sequences for padding
|
||||||
|
self.max_rocket_k_ctx_len = self.k_context_lens[:
|
||||||
|
valid_batch_size].max(
|
||||||
|
).item()
|
||||||
|
else:
|
||||||
|
self.max_rocket_k_ctx_len = 0
|
||||||
|
|
||||||
Returns:
|
sparse_counts_ctx = torch.zeros(self.num_contexts,
|
||||||
Tuple of (page_indices, page_offsets):
|
dtype=torch.int32,
|
||||||
- page_indices: Page-based indices with shape [num_pages, num_kv_heads]
|
device='cpu')
|
||||||
- page_offsets: Updated offsets with shape [batch_size+1] indicating page boundaries for each batch
|
sparse_counts_ctx[valid_seq_indices] = self.prompt_budget
|
||||||
|
sparse_counts_ctx[invalid_seq_indices] = self.prompt_lens_cpu[
|
||||||
|
invalid_seq_indices]
|
||||||
|
|
||||||
Example:
|
self.sparse_offsets_ctx[1:self.num_contexts + 1] = torch.cumsum(
|
||||||
If sparse_attn_indices first dimension is [1, 30, 67] and page_size=32,
|
sparse_counts_ctx, dim=0)
|
||||||
the result will be [0, 2] (token 1 -> page 0, token 30 -> page 0, token 67 -> page 2)
|
self.sparse_offsets_ctx_cuda[:self.num_contexts + 1].copy_(
|
||||||
"""
|
self.sparse_offsets_ctx[:self.num_contexts + 1], non_blocking=True)
|
||||||
page_size = metadata.tokens_per_block
|
|
||||||
batch_size = sparse_attn_offsets.size(0) - 1
|
|
||||||
num_kv_heads = sparse_attn_indices.size(1)
|
|
||||||
|
|
||||||
# Convert token indices to page indices
|
self.q_cu_seqlens[:valid_batch_size + 1] = torch.arange(
|
||||||
page_indices = sparse_attn_indices // page_size
|
valid_batch_size + 1, device='cpu',
|
||||||
|
dtype=torch.int32) * self.window_size
|
||||||
|
self.q_cu_seqlens_cuda[:valid_batch_size + 1].copy_(
|
||||||
|
self.q_cu_seqlens[:valid_batch_size + 1], non_blocking=True)
|
||||||
|
|
||||||
# Process each batch and each kv_head separately to remove duplicates
|
self.k_cu_seqlens[1:valid_batch_size + 1] = torch.cumsum(
|
||||||
new_page_indices_list = []
|
self.k_context_lens[:valid_batch_size], dim=0)
|
||||||
new_offsets = torch.zeros_like(sparse_attn_offsets)
|
self.k_cu_seqlens_cuda[:valid_batch_size + 1].copy_(
|
||||||
|
self.k_cu_seqlens[:valid_batch_size + 1], non_blocking=True)
|
||||||
|
|
||||||
current_offset = 0
|
self.valid_batch_size = valid_batch_size
|
||||||
for batch_idx in range(batch_size):
|
self.total_sparse_ctx_indices = self.sparse_offsets_ctx[
|
||||||
start_idx = sparse_attn_offsets[batch_idx]
|
self.num_contexts].item()
|
||||||
end_idx = sparse_attn_offsets[batch_idx + 1]
|
|
||||||
|
|
||||||
if start_idx >= end_idx:
|
# -------------------------------- Generation phase --------------------------------
|
||||||
# Empty batch
|
self.num_kt_tokens[:self.num_generations] = (
|
||||||
new_offsets[batch_idx + 1] = current_offset
|
self.kv_lens[self.num_contexts:self.num_seqs] + self.page_size -
|
||||||
continue
|
1) // self.page_size
|
||||||
|
|
||||||
batch_page_indices = page_indices[
|
self.cum_kt_lens[1:self.num_generations + 1] = torch.cumsum(
|
||||||
start_idx:end_idx] # [num_tokens_in_batch, num_kv_heads]
|
self.num_kt_tokens[:self.num_generations], dim=0)
|
||||||
|
self.cum_kt_lens_cuda[:self.num_generations + 1].copy_(
|
||||||
|
self.cum_kt_lens[:self.num_generations + 1], non_blocking=True)
|
||||||
|
|
||||||
# For each kv_head, remove duplicates while preserving order
|
self.total_kt_tokens = self.num_generations * self.max_kt_tokens
|
||||||
batch_unique_pages = []
|
|
||||||
for head_idx in range(num_kv_heads):
|
|
||||||
head_pages = batch_page_indices[:, head_idx]
|
|
||||||
unique_pages = torch.unique(head_pages, sorted=False)
|
|
||||||
batch_unique_pages.append(unique_pages)
|
|
||||||
|
|
||||||
# Find the maximum length among all heads for this batch
|
topk_tensor = torch.tensor(self.topk, dtype=torch.int32)
|
||||||
max_len = max(pages.size(0) for pages in batch_unique_pages)
|
|
||||||
|
|
||||||
if max_len > 0:
|
# Some sequences may have less than topk KT tokens
|
||||||
batch_result = torch.full((max_len, num_kv_heads),
|
# We need to use the minimum of topk and the number of KT tokens
|
||||||
fill_value=-1,
|
sparse_counts_gen = torch.minimum(
|
||||||
dtype=page_indices.dtype,
|
topk_tensor, self.num_kt_tokens[:self.num_generations])
|
||||||
device=page_indices.device)
|
|
||||||
|
|
||||||
for head_idx in range(num_kv_heads):
|
self.sparse_offsets_gen[1:self.num_generations + 1] = torch.cumsum(
|
||||||
unique_pages = batch_unique_pages[head_idx]
|
sparse_counts_gen[:self.num_generations], dim=0)
|
||||||
batch_result[:unique_pages.size(0), head_idx] = unique_pages
|
self.sparse_offsets_gen_cuda[:self.num_generations + 1].copy_(
|
||||||
|
self.sparse_offsets_gen[:self.num_generations + 1],
|
||||||
|
non_blocking=True)
|
||||||
|
|
||||||
new_page_indices_list.append(batch_result)
|
self.total_sparse_gen_indices = self.topk * self.num_generations
|
||||||
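Illustrative note: a worked example, with hypothetical values, of the generation-phase bookkeeping added above (new code in the right-hand column): each generation request contributes min(topk, num_kt_tokens) sparse indices, and the offsets are their cumulative sum.

    import torch
    page_size, topk = 8, 128
    kv_lens = torch.tensor([700, 90])                               # cached tokens per generation request
    num_kt_tokens = (kv_lens + page_size - 1) // page_size          # -> tensor([88, 12])
    sparse_counts_gen = torch.minimum(torch.tensor(topk), num_kt_tokens)   # -> tensor([88, 12])
    sparse_offsets_gen = torch.zeros(kv_lens.numel() + 1, dtype=torch.int64)
    sparse_offsets_gen[1:] = torch.cumsum(sparse_counts_gen, dim=0)        # sparse_offsets_gen -> tensor([0, 88, 100])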
current_offset += max_len
|
|
||||||
|
|
||||||
new_offsets[batch_idx + 1] = current_offset
|
|
||||||
|
|
||||||
new_page_indices = torch.cat(new_page_indices_list, dim=0)
|
|
||||||
|
|
||||||
return new_page_indices, new_offsets
|
|
||||||
|
|
||||||
|
|
||||||
class RocketTrtllmAttention(TrtllmAttention):
|
class RocketTrtllmAttention(TrtllmAttention):
|
||||||
@ -213,85 +327,6 @@ class RocketTrtllmAttention(TrtllmAttention):
|
|||||||
self.kernel_size = sparse_attention_config.kernel_size
|
self.kernel_size = sparse_attention_config.kernel_size
|
||||||
self.page_size = sparse_attention_config.page_size
|
self.page_size = sparse_attention_config.page_size
|
||||||
|
|
||||||
def sparse_attn_predict(
|
|
||||||
self,
|
|
||||||
q: torch.Tensor,
|
|
||||||
k: Optional[torch.Tensor],
|
|
||||||
metadata: TrtllmAttentionMetadata,
|
|
||||||
**kwargs,
|
|
||||||
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
|
|
||||||
"""
|
|
||||||
Predict sparse attention indices.
|
|
||||||
For RocketKV:
|
|
||||||
- Generation phase: predict RocketKV sparse attention indices
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
- sparse_attn_indices: [total_selected_indices, num_kv_heads]
|
|
||||||
- sparse_attn_offsets: [batch_size + 1] with cumulative indices count
|
|
||||||
"""
|
|
||||||
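Illustrative note: the returned offsets described above are a cumulative count of selected indices per sequence; a tiny sketch with hypothetical counts:

    import torch
    per_seq_counts = torch.tensor([128, 0, 96], dtype=torch.int32)        # selected indices per sequence
    sparse_attn_offsets = torch.zeros(per_seq_counts.numel() + 1, dtype=torch.int32)
    sparse_attn_offsets[1:] = torch.cumsum(per_seq_counts, dim=0)         # -> [0, 128, 128, 224]
    # sparse_attn_indices would then have shape [224, num_kv_heads] in this example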
if k is None:
|
|
||||||
q, k, _ = q.split([
|
|
||||||
self.num_heads * self.head_dim, self.num_kv_heads *
|
|
||||||
self.head_dim, self.num_kv_heads * self.head_dim
|
|
||||||
],
|
|
||||||
dim=-1)
|
|
||||||
|
|
||||||
num_contexts = metadata.num_contexts
|
|
||||||
num_generations = metadata.num_generations
|
|
||||||
seq_lens = metadata.seq_lens
|
|
||||||
seq_lens_kv = metadata.seq_lens_kv if metadata.seq_lens_kv is not None else seq_lens
|
|
||||||
past_seen_tokens = metadata.kv_cache_params.num_cached_tokens_per_seq
|
|
||||||
|
|
||||||
sparse_attn_indices = []
|
|
||||||
sparse_attn_offsets = [0]
|
|
||||||
|
|
||||||
q_offset = 0
|
|
||||||
k_offset = 0
|
|
||||||
|
|
||||||
for i in range(num_contexts + num_generations):
|
|
||||||
seq_len = seq_lens[i].item()
|
|
||||||
seq_len_kv = seq_lens_kv[i].item()
|
|
||||||
|
|
||||||
if seq_len <= 0 or seq_len_kv <= 0:
|
|
||||||
assert False, "Invalid sequence length"
|
|
||||||
|
|
||||||
single_q = q[q_offset:q_offset + seq_len]
|
|
||||||
single_k = k[k_offset:k_offset + seq_len_kv]
|
|
||||||
|
|
||||||
single_q = single_q.view(1, seq_len, self.num_heads,
|
|
||||||
self.head_dim).transpose(1, 2)
|
|
||||||
single_k = single_k.view(1, seq_len_kv, self.num_kv_heads,
|
|
||||||
self.head_dim)
|
|
||||||
|
|
||||||
past_seen_token = past_seen_tokens[i]
|
|
||||||
# Generation phase: RocketKV sparse attention indices
|
|
||||||
if i >= num_contexts:
|
|
||||||
_sparse_attn_indices = self._rocketkv_selection(
|
|
||||||
single_q, single_k, past_seen_token, metadata, i)
|
|
||||||
if _sparse_attn_indices is not None:
|
|
||||||
sparse_attn_indices.append(
|
|
||||||
_sparse_attn_indices.squeeze(0)) # [topk, num_kv_heads]
|
|
||||||
sparse_attn_offsets.append(sparse_attn_offsets[-1] +
|
|
||||||
_sparse_attn_indices.size(1))
|
|
||||||
else:
|
|
||||||
sparse_attn_offsets.append(sparse_attn_offsets[-1])
|
|
||||||
|
|
||||||
q_offset += seq_len
|
|
||||||
k_offset += seq_len_kv
|
|
||||||
|
|
||||||
if len(sparse_attn_indices) == 0:
|
|
||||||
sparse_attn_indices, sparse_attn_offsets = None, None
|
|
||||||
else:
|
|
||||||
sparse_attn_indices = torch.cat(sparse_attn_indices,
|
|
||||||
dim=0).to(torch.int32)
|
|
||||||
sparse_attn_offsets = torch.tensor(sparse_attn_offsets,
|
|
||||||
dtype=torch.int32).to(q.device)
|
|
||||||
sparse_attn_indices, sparse_attn_offsets = convert_token_to_page_sparse_indices(
|
|
||||||
sparse_attn_indices, sparse_attn_offsets, metadata)
|
|
||||||
sparse_attn_indices = sparse_attn_indices.transpose(0,
|
|
||||||
1).contiguous()
|
|
||||||
return sparse_attn_indices, sparse_attn_offsets
|
|
||||||
|
|
||||||
def sparse_kv_predict(
|
def sparse_kv_predict(
|
||||||
self,
|
self,
|
||||||
q: torch.Tensor,
|
q: torch.Tensor,
|
||||||
@ -300,229 +335,203 @@ class RocketTrtllmAttention(TrtllmAttention):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
|
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||||
"""
|
"""
|
||||||
Predict sparse kv indices.
|
Predict sparse KV indices using optimized SnapKV algorithm.
|
||||||
|
|
||||||
For RocketKV:
|
Uses a single Triton kernel to compute attention scores between the observation window
|
||||||
- Context phase: predict RocketKV sparse kv indices
|
and prefix tokens, then selects the most important prefix tokens directly.
|
||||||
|
|
||||||
Returns:
|
|
||||||
- flattened_indices: [total_selected_indices, num_kv_heads]
|
|
||||||
- batch_offsets: [batch_size + 1] with cumulative indices count
|
|
||||||
"""
|
"""
|
||||||
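Illustrative note: a rough, dense-PyTorch sketch of the SnapKV-style selection the new docstring describes (hypothetical shapes, per-KV-group reduction and causal masking omitted; the diff itself performs this with fused Triton kernels on the flattened batch):

    import torch
    import torch.nn.functional as F

    def snapkv_select_sketch(q_window, k_prefix, kernel_size=3, budget=224):
        # q_window: [num_heads, window, head_dim], k_prefix: [num_heads, prefix_len, head_dim]
        # budget must not exceed prefix_len; in the real code budget = prompt_budget - window_size
        scores = (q_window @ k_prefix.transpose(-1, -2)) / q_window.shape[-1] ** 0.5
        scores = F.softmax(scores, dim=-1).sum(dim=1)            # aggregate over the observation window
        scores = F.max_pool1d(scores.unsqueeze(0), kernel_size,
                              stride=1, padding=kernel_size // 2).squeeze(0)
        return scores.topk(budget, dim=-1).indices.sort(dim=-1).values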
|
|
||||||
|
num_ctx_tokens = metadata.num_ctx_tokens
|
||||||
|
if num_ctx_tokens == 0:
|
||||||
|
return None, None
|
||||||
|
|
||||||
if k is None:
|
if k is None:
|
||||||
q, k, _ = q.split([
|
qkv_input = q[:num_ctx_tokens]
|
||||||
self.num_heads * self.head_dim, self.num_kv_heads *
|
|
||||||
self.head_dim, self.num_kv_heads * self.head_dim
|
|
||||||
],
|
|
||||||
dim=-1)
|
|
||||||
|
|
||||||
num_contexts = metadata.num_contexts
|
|
||||||
num_generations = metadata.num_generations
|
|
||||||
seq_lens = metadata.seq_lens
|
|
||||||
seq_lens_kv = metadata.seq_lens_kv if metadata.seq_lens_kv is not None else seq_lens
|
|
||||||
past_seen_tokens = metadata.kv_cache_params.num_cached_tokens_per_seq
|
|
||||||
|
|
||||||
sparse_kv_indices = []
|
|
||||||
sparse_kv_offsets = [0]
|
|
||||||
|
|
||||||
q_offset = 0
|
|
||||||
k_offset = 0
|
|
||||||
|
|
||||||
for i in range(num_contexts + num_generations):
|
|
||||||
seq_len = seq_lens[i].item()
|
|
||||||
seq_len_kv = seq_lens_kv[i].item()
|
|
||||||
|
|
||||||
if seq_len <= 0 or seq_len_kv <= 0:
|
|
||||||
assert False, "Invalid sequence length"
|
|
||||||
|
|
||||||
single_q = q[q_offset:q_offset + seq_len]
|
|
||||||
single_k = k[k_offset:k_offset + seq_len_kv]
|
|
||||||
|
|
||||||
single_q = single_q.view(1, seq_len, self.num_heads,
|
|
||||||
self.head_dim).transpose(1, 2)
|
|
||||||
single_k = single_k.view(1, seq_len_kv, self.num_kv_heads,
|
|
||||||
self.head_dim)
|
|
||||||
|
|
||||||
past_seen_token = past_seen_tokens[i]
|
|
||||||
if i < num_contexts:
|
|
||||||
# Context phase: SnapKV sparse kv indices
|
|
||||||
_sparse_kv_indices = self._get_snapkv_indices(
|
|
||||||
single_q, single_k, past_seen_token, metadata, i)
|
|
||||||
if _sparse_kv_indices is not None:
|
|
||||||
sparse_kv_indices.append(
|
|
||||||
_sparse_kv_indices.squeeze(0)) # [budget, num_kv_heads]
|
|
||||||
sparse_kv_offsets.append(sparse_kv_offsets[-1] +
|
|
||||||
_sparse_kv_indices.size(1))
|
|
||||||
else:
|
|
||||||
sparse_kv_offsets.append(sparse_kv_offsets[-1])
|
|
||||||
|
|
||||||
q_offset += seq_len
|
|
||||||
k_offset += seq_len_kv
|
|
||||||
|
|
||||||
if len(sparse_kv_indices) == 0:
|
|
||||||
sparse_kv_indices, sparse_kv_offsets = None, None
|
|
||||||
else:
|
else:
|
||||||
sparse_kv_indices = torch.cat(sparse_kv_indices,
|
qkv_input = torch.cat([q, k], dim=1)
|
||||||
dim=0).to(torch.int32)
|
|
||||||
sparse_kv_indices = sparse_kv_indices.transpose(0, 1).contiguous()
|
|
||||||
sparse_kv_offsets = torch.tensor(sparse_kv_offsets,
|
|
||||||
dtype=torch.int32).to(q.device)
|
|
||||||
return sparse_kv_indices, sparse_kv_offsets
|
|
||||||
|
|
||||||
def _get_snapkv_indices(self, q: Tensor, k: Tensor, past_seen_token: int,
|
if metadata.valid_batch_size > 0:
|
||||||
metadata: RocketTrtllmAttentionMetadata,
|
q_window, k_context = triton_rocket_qk_split(
|
||||||
sample_idx: int) -> Optional[Tensor]:
|
qkv_input,
|
||||||
"""
|
metadata.prompt_lens_cuda,
|
||||||
Get SnapKV sparse kv indices from the input sequence for context phase.
|
metadata.context_cumsum_cuda,
|
||||||
The shape of output is (1, prompt_budget, num_kv_heads)
|
metadata.valid_seq_indices_cuda,
|
||||||
"""
|
metadata.k_cu_seqlens_cuda,
|
||||||
bsz = 1
|
self.num_heads,
|
||||||
seq_len = k.size(1) # k shape: (1, seq_len, num_kv_heads, head_dim)
|
self.num_kv_heads,
|
||||||
|
self.head_dim,
|
||||||
|
self.window_size,
|
||||||
|
metadata.valid_batch_size,
|
||||||
|
)
|
||||||
|
|
||||||
# If the sequence length is less than the prompt budget, do not enable sparse kv cache
|
scores = triton_bmm(q_window,
|
||||||
if seq_len <= self.prompt_budget:
|
k_context,
|
||||||
return None
|
metadata.q_cu_seqlens_cuda,
|
||||||
|
metadata.k_cu_seqlens_cuda,
|
||||||
|
metadata.valid_batch_size,
|
||||||
|
causal=False)
|
||||||
|
|
||||||
# Use last window_size tokens as observation
|
scores = triton_softmax(scores, metadata.k_cu_seqlens_cuda,
|
||||||
# (1, num_heads, window_size, head_dim)
|
metadata.valid_batch_size)
|
||||||
q_obs = q[:, :, -self.window_size:]
|
|
||||||
# (1, num_kv_heads, seq_len, head_dim)
|
|
||||||
k_pre = repeat_kv(k.transpose(1, 2),
|
|
||||||
self.num_heads // self.num_kv_heads)
|
|
||||||
|
|
||||||
dist = (torch.arange(0, self.window_size, device=q.device)[:, None] -
|
# scores: [num_heads, window_size, total_k_tokens] -> [num_kv_heads, total_k_tokens]
|
||||||
torch.arange(0, seq_len, device=q.device)[None, :] + seq_len -
|
scores = scores.view(self.num_kv_heads,
|
||||||
self.window_size)
|
self.num_heads // self.num_kv_heads,
|
||||||
attention_mask = (dist >= 0)
|
self.window_size, -1).sum(dim=(1, 2))
|
||||||
|
|
||||||
score = torch.matmul(q_obs, k_pre.transpose(-1, -2)) / math.sqrt(
|
# Reshape scores to handle variable length sequences with padding using Triton
|
||||||
self.head_dim)
|
# scores: [num_kv_heads, total_k_tokens] -> [valid_batch_size, num_kv_heads, padding_size]
|
||||||
|
scores = triton_flatten_to_batch(scores, metadata.k_cu_seqlens_cuda,
|
||||||
|
metadata.valid_batch_size,
|
||||||
|
metadata.max_rocket_k_ctx_len)
|
||||||
|
|
||||||
score = torch.masked_fill(
|
scores = torch.nn.functional.max_pool1d(
|
||||||
score,
|
scores,
|
||||||
attention_mask.view(1, 1, self.window_size, seq_len) == False,
|
kernel_size=self.kernel_size,
|
||||||
torch.scalar_tensor(float("-inf"),
|
padding=self.kernel_size // 2,
|
||||||
device=score.device,
|
stride=1)
|
||||||
dtype=score.dtype))
|
|
||||||
|
|
||||||
score = torch.nn.functional.softmax(score, dim=-1)
|
selected_prefix_indices = scores.topk(
|
||||||
|
self.prompt_budget - self.window_size,
|
||||||
|
dim=-1).indices.sort().values.to(torch.int32)
|
||||||
|
else:
|
||||||
|
selected_prefix_indices = torch.empty(
|
||||||
|
(0, self.num_kv_heads, self.prompt_budget - self.window_size),
|
||||||
|
device=qkv_input.device,
|
||||||
|
dtype=torch.int32)
|
||||||
|
|
||||||
score = torch.masked_fill(
|
sparse_kv_offsets = metadata.sparse_offsets_ctx_cuda[:metadata.
|
||||||
score,
|
num_contexts + 1]
|
||||||
attention_mask.view(1, 1, self.window_size, seq_len) == False,
|
|
||||||
torch.scalar_tensor(0, device=score.device, dtype=score.dtype))
|
|
||||||
|
|
||||||
score = score[:, :, -self.window_size:, :-self.window_size].sum(dim=-2)
|
# Flatten sparse indices
|
||||||
|
sparse_kv_indices = triton_rocket_batch_to_flatten(
|
||||||
|
selected_prefix_indices, metadata.prompt_lens_cuda,
|
||||||
|
metadata.valid_seq_indices_cuda, sparse_kv_offsets,
|
||||||
|
metadata.num_contexts, metadata.total_sparse_ctx_indices,
|
||||||
|
self.window_size, self.prompt_budget)
|
||||||
|
|
||||||
score = score.view(bsz, self.num_kv_heads,
|
|
||||||
self.num_heads // self.num_kv_heads, -1).sum(dim=2)
|
|
||||||
score = torch.nn.functional.max_pool1d(score,
|
|
||||||
kernel_size=self.kernel_size,
|
|
||||||
padding=self.kernel_size // 2,
|
|
||||||
stride=1)
|
|
||||||
|
|
||||||
# Select top important tokens from prefix
|
|
||||||
prefix_len = seq_len - self.window_size
|
|
||||||
selected_prefix_indices = score.topk(self.prompt_budget -
|
|
||||||
self.window_size,
|
|
||||||
dim=-1).indices.sort().values
|
|
||||||
|
|
||||||
# Combine selected prefix indices with window indices
|
|
||||||
window_indices = torch.arange(
|
|
||||||
prefix_len, seq_len,
|
|
||||||
device=k.device).unsqueeze(0).unsqueeze(0).expand(
|
|
||||||
bsz, self.num_kv_heads, -1)
|
|
||||||
selected_indices = torch.cat([selected_prefix_indices, window_indices],
|
|
||||||
dim=-1).transpose(1, 2)
|
|
||||||
|
|
||||||
k = k.reshape(1, -1, self.num_kv_heads, self.head_dim)
|
|
||||||
k_snap = triton_index_gather(k, selected_indices)
|
|
||||||
# Update KT cache
|
# Update KT cache
|
||||||
kt_cache_tensor = metadata.kv_cache_manager.get_kt_buffers(
|
kt_cache_tensor = metadata.kv_cache_manager.get_kt_buffers(
|
||||||
self.layer_idx)
|
self.layer_idx)
|
||||||
k_snap_len = torch.clamp(
|
|
||||||
metadata.kv_lens_cuda_runtime[sample_idx:sample_idx + 1],
|
triton_rocket_update_kt_cache_ctx(
|
||||||
max=self.prompt_budget).int()
|
qkv_input.contiguous(),
|
||||||
triton_update_kt_cache(
|
|
||||||
k_snap.squeeze(0).contiguous(),
|
|
||||||
kt_cache_tensor,
|
kt_cache_tensor,
|
||||||
metadata.kt_cache_block_offsets[sample_idx:sample_idx + 1],
|
metadata.kt_cache_block_offsets[:metadata.num_contexts],
|
||||||
k_snap_len,
|
metadata.context_cumsum_cuda[:metadata.num_contexts + 1],
|
||||||
|
sparse_kv_indices,
|
||||||
|
sparse_kv_offsets,
|
||||||
|
self.num_heads,
|
||||||
|
self.num_kv_heads,
|
||||||
|
self.head_dim,
|
||||||
self.page_size,
|
self.page_size,
|
||||||
metadata.kt_tokens_per_block,
|
metadata.kt_tokens_per_block,
|
||||||
metadata.kv_cache_manager.max_kt_blocks_per_seq,
|
metadata.kv_cache_manager.max_kt_blocks_per_seq,
|
||||||
update=False)
|
)
|
||||||
|
|
||||||
return selected_indices
|
# Reduce the overhead of post-processing
|
||||||
|
if metadata.valid_batch_size == 0:
|
||||||
|
return None, None
|
||||||
|
|
||||||
def _rocketkv_selection(self, q: Tensor, k: Tensor, past_seen_token: int,
|
return sparse_kv_indices, sparse_kv_offsets
|
||||||
metadata: RocketTrtllmAttentionMetadata,
|
|
||||||
sample_idx: int) -> Tensor:
|
|
||||||
"""
|
|
||||||
Implement RocketKV's two-stage selection process for generation phase.
|
|
||||||
The shape of output is (1, topk, num_kv_heads)
|
|
||||||
"""
|
|
||||||
bsz = 1
|
|
||||||
q_len = q.size(2)
|
|
||||||
|
|
||||||
# Helper functions
|
@torch.compile(dynamic=True)
|
||||||
def _gather(t: Tensor, dim: int, i: Tensor) -> Tensor:
|
def preprocess_for_gen(
|
||||||
dim += (dim < 0) * t.ndim
|
self, q: torch.Tensor, k: Optional[torch.Tensor],
|
||||||
return t.gather(
|
metadata: RocketTrtllmAttentionMetadata
|
||||||
dim, i.expand(*t.shape[:dim], i.shape[dim], *t.shape[dim + 1:]))
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
|
if k is None:
|
||||||
|
qkv_input = q[metadata.num_ctx_tokens:]
|
||||||
|
q_hidden_size = self.num_heads * self.head_dim
|
||||||
|
k_hidden_size = self.num_kv_heads * self.head_dim
|
||||||
|
q = qkv_input[:, :q_hidden_size]
|
||||||
|
k = qkv_input[:, q_hidden_size:q_hidden_size + k_hidden_size]
|
||||||
|
else:
|
||||||
|
q = q[metadata.num_ctx_tokens:]
|
||||||
|
k = k[metadata.num_ctx_tokens:]
|
||||||
|
|
||||||
@torch.compile(disable=not torch.cuda.is_available())
|
q = q.view(-1, self.num_kv_heads, self.num_heads // self.num_kv_heads,
|
||||||
def _scaled_softmax(x: Tensor, divscale: Tensor | float,
|
self.head_dim)
|
||||||
dim: int) -> Tensor:
|
|
||||||
return torch.softmax(x / divscale, dim=dim)
|
q_abs = torch.abs(q)
|
||||||
|
q_mask = torch.zeros_like(q)
|
||||||
|
|
||||||
|
i1 = torch.topk(q_abs.mean(dim=2, keepdim=True), self.topr,
|
||||||
|
dim=-1).indices
|
||||||
|
|
||||||
|
q_mask.scatter_(-1, i1.expand_as(q[..., :self.topr]), 1)
|
||||||
|
|
||||||
|
q_valid = q * q_mask
|
||||||
|
|
||||||
|
dim_pos = torch.where(q_valid.sum(dim=2) > 0, self.head_dim,
|
||||||
|
0).to(torch.int32)
|
||||||
|
|
||||||
|
return q_valid, k, dim_pos
|
||||||
|
|
||||||
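Illustrative note: a tiny standalone example (hypothetical sizes) of the top-r feature masking performed in preprocess_for_gen above, where only the topr largest-magnitude query dimensions per KV head are kept:

    import torch
    q = torch.randn(4, 2, 8, 64)              # [gen_tokens, num_kv_heads, heads_per_kv, head_dim]
    topr = 16
    i1 = torch.topk(q.abs().mean(dim=2, keepdim=True), topr, dim=-1).indices
    q_mask = torch.zeros_like(q)
    q_mask.scatter_(-1, i1.expand_as(q[..., :topr]), 1)
    q_valid = q * q_mask                       # feature-sparse query used for approximate scoring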
|
def sparse_attn_predict(
|
||||||
|
self,
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: Optional[torch.Tensor],
|
||||||
|
metadata: TrtllmAttentionMetadata,
|
||||||
|
**kwargs,
|
||||||
|
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||||
|
if metadata.num_generations == 0:
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
q, k, dim_pos = self.preprocess_for_gen(q, k, metadata)
|
||||||
|
|
||||||
# Get KT cache for key-token matching
|
|
||||||
kt_cache_tensor = metadata.kv_cache_manager.get_kt_buffers(
|
kt_cache_tensor = metadata.kv_cache_manager.get_kt_buffers(
|
||||||
self.layer_idx)
|
self.layer_idx)
|
||||||
target_seq_len = past_seen_token + 1 # +1 for current token
|
|
||||||
|
|
||||||
# Update KT cache
|
# Update KT cache with new key values
|
||||||
kt_states = triton_update_kt_cache(
|
triton_rocket_update_kt_cache_gen(
|
||||||
k.squeeze(0).contiguous(), kt_cache_tensor,
|
k,
|
||||||
metadata.kt_cache_block_offsets[sample_idx:sample_idx + 1],
|
kt_cache_tensor,
|
||||||
metadata.kv_lens_cuda_runtime[sample_idx:sample_idx + 1],
|
metadata.kt_cache_block_offsets[metadata.num_contexts:],
|
||||||
self.page_size, metadata.kt_tokens_per_block,
|
metadata.kv_lens_cuda_runtime[metadata.num_contexts:],
|
||||||
metadata.kv_cache_manager.max_kt_blocks_per_seq)
|
metadata.page_size,
|
||||||
kt_states = kt_states.unsqueeze(0).permute(0, 2, 3, 1)
|
metadata.kt_tokens_per_block,
|
||||||
|
metadata.kv_cache_manager.max_kt_blocks_per_seq,
|
||||||
|
self.num_kv_heads,
|
||||||
|
self.head_dim,
|
||||||
|
)
|
||||||
|
|
||||||
# Reshape query for multi-head processing
|
# Perform BMM with updated cache
|
||||||
qi = q.view(bsz, self.num_kv_heads, self.num_heads // self.num_kv_heads,
|
scores = triton_rocket_paged_kt_cache_bmm(
|
||||||
q_len, self.head_dim)
|
q,
|
||||||
qi_abs = torch.abs(qi)
|
kt_cache_tensor,
|
||||||
|
metadata.kt_cache_block_offsets[metadata.num_contexts:],
|
||||||
|
dim_pos,
|
||||||
|
metadata.kv_lens_cuda_runtime[metadata.num_contexts:],
|
||||||
|
metadata.cum_kt_lens_cuda,
|
||||||
|
metadata.page_size,
|
||||||
|
metadata.kt_tokens_per_block,
|
||||||
|
metadata.kv_cache_manager.max_kt_blocks_per_seq,
|
||||||
|
metadata.total_kt_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
# Top-r selection on query features
|
scores = triton_softmax(scores, metadata.cum_kt_lens_cuda,
|
||||||
i1 = torch.topk(qi_abs.mean(dim=2, keepdim=True), self.topr,
|
metadata.num_generations)
|
||||||
dim=-1).indices
|
|
||||||
qi_hat = _gather(qi, -1, i1)
|
|
||||||
|
|
||||||
# Generate signed indices for key-token matching
|
# Mean over num_heads_per_kv for each batch separately
|
||||||
i1_sign = torch.where(
|
scores = triton_rocket_reduce_scores(
|
||||||
qi_hat.sum(dim=2, keepdim=True) > 0, i1 + self.head_dim,
|
scores,
|
||||||
i1).transpose(-1, -2)
|
metadata.cum_kt_lens_cuda,
|
||||||
|
metadata.num_generations,
|
||||||
|
self.num_kv_heads,
|
||||||
|
self.num_heads // self.num_kv_heads,
|
||||||
|
)
|
||||||
|
|
||||||
# Gather key tokens and compute attention scores
|
sparse_attn_offsets = metadata.sparse_offsets_gen_cuda[:metadata.
|
||||||
kt_hat = _gather(kt_states.unsqueeze(2), -2, i1_sign)
|
num_generations +
|
||||||
qk_hat = qi_hat @ kt_hat
|
1]
|
||||||
qk_hat = qk_hat.repeat_interleave(self.page_size,
|
|
||||||
dim=-1)[:, :, :, :, :target_seq_len]
|
|
||||||
scale = torch.sqrt(self.head_dim *
|
|
||||||
torch.abs(qi_hat).sum(dim=-1, keepdim=True) /
|
|
||||||
qi_abs.sum(dim=-1, keepdim=True))
|
|
||||||
|
|
||||||
# (1, num_kv_heads, num_heads, target_seq_len)
|
selected_indices = triton_topk(scores, metadata.cum_kt_lens_cuda,
|
||||||
s_hat = _scaled_softmax(qk_hat, scale, dim=-1)
|
sparse_attn_offsets,
|
||||||
|
metadata.total_sparse_gen_indices,
|
||||||
|
metadata.topk)
|
||||||
|
|
||||||
topk = min(self.topk, target_seq_len)
|
return selected_indices, sparse_attn_offsets
|
||||||
i2 = torch.topk(s_hat.mean(dim=2, keepdim=True), topk, dim=-1).indices
|
|
||||||
|
|
||||||
iKV = i2[:, :, 0, 0, :].transpose(1, 2) # (1, topk, num_kv_heads)
|
|
||||||
|
|
||||||
return iKV
|
|
||||||
|
|
||||||
|
|
||||||
class RocketVanillaAttentionMetadata(VanillaAttentionMetadata):
|
class RocketVanillaAttentionMetadata(VanillaAttentionMetadata):
|
||||||
@ -920,6 +929,7 @@ class RocketKVCacheManager(KVCacheManager):
|
|||||||
max_num_draft_tokens: int = 0,
|
max_num_draft_tokens: int = 0,
|
||||||
use_mrope: bool = False,
|
use_mrope: bool = False,
|
||||||
max_beam_width: int = 1,
|
max_beam_width: int = 1,
|
||||||
|
num_extra_decoding_steps: int = 0,
|
||||||
):
|
):
|
||||||
requests = super().add_dummy_requests(
|
requests = super().add_dummy_requests(
|
||||||
request_ids=request_ids,
|
request_ids=request_ids,
|
||||||
@ -929,6 +939,7 @@ class RocketKVCacheManager(KVCacheManager):
|
|||||||
max_num_draft_tokens=max_num_draft_tokens,
|
max_num_draft_tokens=max_num_draft_tokens,
|
||||||
use_mrope=use_mrope,
|
use_mrope=use_mrope,
|
||||||
max_beam_width=max_beam_width,
|
max_beam_width=max_beam_width,
|
||||||
|
num_extra_decoding_steps=num_extra_decoding_steps,
|
||||||
)
|
)
|
||||||
if prepare_resource:
|
if prepare_resource:
|
||||||
for req in requests:
|
for req in requests:
|
||||||
|
|||||||
@ -197,6 +197,7 @@ class TrtllmAttentionWrapper:
|
|||||||
sparse_kv_offsets: Optional[torch.Tensor] = None,
|
sparse_kv_offsets: Optional[torch.Tensor] = None,
|
||||||
sparse_attn_indices: Optional[torch.Tensor] = None,
|
sparse_attn_indices: Optional[torch.Tensor] = None,
|
||||||
sparse_attn_offsets: Optional[torch.Tensor] = None,
|
sparse_attn_offsets: Optional[torch.Tensor] = None,
|
||||||
|
sparse_attn_indices_block_size: int = 1,
|
||||||
sparse_mla_topk: int = 0,
|
sparse_mla_topk: int = 0,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
@ -241,6 +242,7 @@ class TrtllmAttentionWrapper:
|
|||||||
sparse_kv_offsets (torch.Tensor): The batch offsets for the sparse KV indices, with shape of (num_contexts + 1) on GPU.
|
sparse_kv_offsets (torch.Tensor): The batch offsets for the sparse KV indices, with shape of (num_contexts + 1) on GPU.
|
||||||
sparse_attn_indices (torch.Tensor): The sparse indices for the attention layer, with shape of (num_heads_kv, num_sparse_tokens) on GPU.
|
sparse_attn_indices (torch.Tensor): The sparse indices for the attention layer, with shape of (num_heads_kv, num_sparse_tokens) on GPU.
|
||||||
sparse_attn_offsets (torch.Tensor): The batch offsets for the sparse attention indices, with shape of (num_generations + 1) on GPU.
|
sparse_attn_offsets (torch.Tensor): The batch offsets for the sparse attention indices, with shape of (num_generations + 1) on GPU.
|
||||||
|
sparse_attn_indices_block_size (int): The granularity of the sparse attention indices, used by block sparse attention.
|
||||||
sparse_mla_topk (int): The topk for the sparse MLA, used by DSA attention.
|
sparse_mla_topk (int): The topk for the sparse MLA, used by DSA attention.
|
||||||
"""
|
"""
|
||||||
self.layer_idx = layer_idx
|
self.layer_idx = layer_idx
|
||||||
@ -283,6 +285,7 @@ class TrtllmAttentionWrapper:
|
|||||||
self.sparse_kv_offsets = sparse_kv_offsets
|
self.sparse_kv_offsets = sparse_kv_offsets
|
||||||
self.sparse_attn_indices = sparse_attn_indices
|
self.sparse_attn_indices = sparse_attn_indices
|
||||||
self.sparse_attn_offsets = sparse_attn_offsets
|
self.sparse_attn_offsets = sparse_attn_offsets
|
||||||
|
self.sparse_attn_indices_block_size = sparse_attn_indices_block_size
|
||||||
self.sparse_mla_topk = sparse_mla_topk
|
self.sparse_mla_topk = sparse_mla_topk
|
||||||
if max_sequence_length > self.rope_params.max_positions:
|
if max_sequence_length > self.rope_params.max_positions:
|
||||||
self.rope_params.max_positions = max_sequence_length
|
self.rope_params.max_positions = max_sequence_length
|
||||||
@ -525,6 +528,7 @@ class TrtllmAttentionWrapper:
|
|||||||
self.sparse_kv_offsets,
|
self.sparse_kv_offsets,
|
||||||
self.sparse_attn_indices,
|
self.sparse_attn_indices,
|
||||||
self.sparse_attn_offsets,
|
self.sparse_attn_offsets,
|
||||||
|
self.sparse_attn_indices_block_size,
|
||||||
self.sparse_mla_topk,
|
self.sparse_mla_topk,
|
||||||
cu_q_seqlens,
|
cu_q_seqlens,
|
||||||
cu_kv_seqlens,
|
cu_kv_seqlens,
|
||||||
@ -1308,11 +1312,14 @@ class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
sparse_kv_indices, sparse_kv_offsets, sparse_attn_indices, sparse_attn_offsets = None, None, None, None
|
sparse_kv_indices, sparse_kv_offsets, sparse_attn_indices, sparse_attn_offsets = None, None, None, None
|
||||||
|
sparse_attn_indices_block_size = 1
|
||||||
if self.sparse_attention_config is not None:
|
if self.sparse_attention_config is not None:
|
||||||
sparse_kv_indices, sparse_kv_offsets = self.sparse_kv_predict(
|
sparse_kv_indices, sparse_kv_offsets = self.sparse_kv_predict(
|
||||||
q, k, metadata, **kwargs)
|
q, k, metadata, **kwargs)
|
||||||
sparse_attn_indices, sparse_attn_offsets = self.sparse_attn_predict(
|
sparse_attn_indices, sparse_attn_offsets = self.sparse_attn_predict(
|
||||||
q, k, metadata, **kwargs)
|
q, k, metadata, **kwargs)
|
||||||
|
sparse_attn_indices_block_size = self.sparse_attention_config.get_indices_block_size(
|
||||||
|
)
|
||||||
|
|
||||||
self.wrapper.plan(
|
self.wrapper.plan(
|
||||||
layer_idx=self.get_local_layer_idx(metadata),
|
layer_idx=self.get_local_layer_idx(metadata),
|
||||||
@ -1366,8 +1373,10 @@ class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]):
|
|||||||
sparse_kv_offsets=sparse_kv_offsets,
|
sparse_kv_offsets=sparse_kv_offsets,
|
||||||
sparse_attn_indices=sparse_attn_indices,
|
sparse_attn_indices=sparse_attn_indices,
|
||||||
sparse_attn_offsets=sparse_attn_offsets,
|
sparse_attn_offsets=sparse_attn_offsets,
|
||||||
|
sparse_attn_indices_block_size=sparse_attn_indices_block_size,
|
||||||
sparse_mla_topk=metadata.sparse_mla_topk if hasattr(
|
sparse_mla_topk=metadata.sparse_mla_topk if hasattr(
|
||||||
metadata, 'sparse_mla_topk') else 0)
|
metadata, 'sparse_mla_topk') else 0,
|
||||||
|
)
|
||||||
out_dtype = None
|
out_dtype = None
|
||||||
if out_scale is not None:
|
if out_scale is not None:
|
||||||
if use_nvfp4_output:
|
if use_nvfp4_output:
|
||||||
@ -1589,18 +1598,6 @@ class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]):
|
|||||||
self.mla_params.v_head_dim,
|
self.mla_params.v_head_dim,
|
||||||
)
|
)
|
||||||
|
|
||||||
def sparse_attn_predict(
|
|
||||||
self,
|
|
||||||
q: torch.Tensor,
|
|
||||||
k: Optional[torch.Tensor],
|
|
||||||
metadata: TrtllmAttentionMetadata,
|
|
||||||
**kwargs,
|
|
||||||
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
|
|
||||||
"""
|
|
||||||
Predict sparse attn indices. It's implemented in the derived class.
|
|
||||||
"""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def sparse_kv_predict(
|
def sparse_kv_predict(
|
||||||
self,
|
self,
|
||||||
q: torch.Tensor,
|
q: torch.Tensor,
|
||||||
@ -1613,6 +1610,18 @@ class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]):
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def sparse_attn_predict(
|
||||||
|
self,
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: Optional[torch.Tensor],
|
||||||
|
metadata: TrtllmAttentionMetadata,
|
||||||
|
**kwargs,
|
||||||
|
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||||
|
"""
|
||||||
|
Predict sparse attn indices. It's implemented in the derived class.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
def mla_rope_generation(
|
def mla_rope_generation(
|
||||||
self,
|
self,
|
||||||
fused_q: torch.Tensor,
|
fused_q: torch.Tensor,
|
||||||
|
|||||||
@ -215,6 +215,9 @@ class BaseSparseAttentionConfig(StrictBaseModel):
|
|||||||
"""
|
"""
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def get_indices_block_size(self) -> int:
|
||||||
|
return 1
|
||||||
|
|
||||||
|
|
||||||
class RocketSparseAttentionConfig(BaseSparseAttentionConfig):
|
class RocketSparseAttentionConfig(BaseSparseAttentionConfig):
|
||||||
"""
|
"""
|
||||||
@ -238,6 +241,9 @@ class RocketSparseAttentionConfig(BaseSparseAttentionConfig):
|
|||||||
def supports_backend(self, backend: str) -> bool:
|
def supports_backend(self, backend: str) -> bool:
|
||||||
return backend == "pytorch"
|
return backend == "pytorch"
|
||||||
|
|
||||||
|
def get_indices_block_size(self) -> int:
|
||||||
|
return self.page_size
|
||||||
|
|
||||||
|
|
||||||
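Illustrative note: the new hook above lets each sparse-attention config report the granularity of its indices; for RocketKV they are page indices, so the block size equals page_size, while the base config stays at token granularity. A small usage sketch with the same values as the unit tests below:

    config = RocketSparseAttentionConfig(window_size=32, kernel_size=3,
                                         prompt_budget=256, page_size=3)
    assert config.get_indices_block_size() == 3      # one sparse index covers page_size tokens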
class DeepSeekSparseAttentionConfig(BaseSparseAttentionConfig):
|
class DeepSeekSparseAttentionConfig(BaseSparseAttentionConfig):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -2,10 +2,18 @@ import json
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
import torch
|
||||||
from utils.llm_data import llm_models_root
|
from utils.llm_data import llm_models_root
|
||||||
|
|
||||||
|
import tensorrt_llm
|
||||||
from tensorrt_llm import LLM, SamplingParams
|
from tensorrt_llm import LLM, SamplingParams
|
||||||
from tensorrt_llm.llmapi import KvCacheConfig, RocketSparseAttentionConfig
|
from tensorrt_llm._torch.attention_backend.sparse.rocket import (
|
||||||
|
RocketKVCacheManager, RocketTrtllmAttention, RocketTrtllmAttentionMetadata,
|
||||||
|
RocketVanillaAttention, RocketVanillaAttentionMetadata)
|
||||||
|
from tensorrt_llm._torch.metadata import KVCacheParams
|
||||||
|
from tensorrt_llm.llmapi import (CudaGraphConfig, KvCacheConfig,
|
||||||
|
RocketSparseAttentionConfig)
|
||||||
|
from tensorrt_llm.mapping import Mapping
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("backend", ["pytorch"])
|
@pytest.mark.parametrize("backend", ["pytorch"])
|
||||||
@ -25,6 +33,11 @@ def test_model(backend, model_name, attention_backend):
|
|||||||
prompt_budget=2048,
|
prompt_budget=2048,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cuda_graph_config = CudaGraphConfig(
|
||||||
|
batch_sizes=[1, 2, 4, 8, 16],
|
||||||
|
enable_padding=True,
|
||||||
|
)
|
||||||
|
|
||||||
llm = LLM(
|
llm = LLM(
|
||||||
model=model_dir,
|
model=model_dir,
|
||||||
backend=backend,
|
backend=backend,
|
||||||
@ -32,10 +45,10 @@ def test_model(backend, model_name, attention_backend):
|
|||||||
attn_backend=attention_backend,
|
attn_backend=attention_backend,
|
||||||
sparse_attention_config=sparse_attention_config,
|
sparse_attention_config=sparse_attention_config,
|
||||||
max_batch_size=max_batch_size,
|
max_batch_size=max_batch_size,
|
||||||
max_seq_len=8192,
|
max_seq_len=20480,
|
||||||
max_num_tokens=8192,
|
max_num_tokens=81920,
|
||||||
cuda_graph_config=
|
cuda_graph_config=None
|
||||||
None, # sparse attention does not support cuda graph now
|
if attention_backend == "VANILLA" else cuda_graph_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
inputs, references = [], []
|
inputs, references = [], []
|
||||||
@ -75,6 +88,559 @@ def test_model(backend, model_name, attention_backend):
|
|||||||
assert acc >= 0.9, 'accuracy test of rocketkv sparse attention failed'
|
assert acc >= 0.9, 'accuracy test of rocketkv sparse attention failed'
|
||||||
|
|
||||||
|
|
||||||
|
def create_rocket_kv_cache_manager(num_layers, num_kv_heads, head_dim,
|
||||||
|
tokens_per_block, max_seq_len,
|
||||||
|
max_batch_size, dtype, sparse_attn_config):
|
||||||
|
mapping = Mapping(world_size=1, tp_size=1, rank=0)
|
||||||
|
num_blocks = 100
|
||||||
|
kv_cache_config = KvCacheConfig(max_tokens=num_blocks * tokens_per_block,
|
||||||
|
enable_block_reuse=False)
|
||||||
|
|
||||||
|
kv_cache_manager = RocketKVCacheManager(
|
||||||
|
kv_cache_config,
|
||||||
|
tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF,
|
||||||
|
num_layers=num_layers,
|
||||||
|
num_kv_heads=num_kv_heads,
|
||||||
|
head_dim=head_dim,
|
||||||
|
tokens_per_block=tokens_per_block,
|
||||||
|
max_seq_len=max_seq_len,
|
||||||
|
max_batch_size=max_batch_size,
|
||||||
|
mapping=mapping,
|
||||||
|
dtype=dtype,
|
||||||
|
sparse_attn_config=sparse_attn_config,
|
||||||
|
)
|
||||||
|
return kv_cache_manager
|
||||||
|
|
||||||
|
|
||||||
|
def create_test_metadata(seq_lens, num_contexts, past_seen_tokens, request_ids,
|
||||||
|
kv_cache_manager, sparse_attn_config, metadata_cls):
|
||||||
|
prompt_lens = []
|
||||||
|
for i, (seq_len, past_token) in enumerate(zip(seq_lens, past_seen_tokens)):
|
||||||
|
if i < num_contexts:
|
||||||
|
prompt_lens.append(seq_len)
|
||||||
|
else:
|
||||||
|
prompt_lens.append(past_token + seq_len)
|
||||||
|
|
||||||
|
metadata = metadata_cls(
|
||||||
|
seq_lens=torch.tensor(seq_lens, dtype=torch.int),
|
||||||
|
num_contexts=num_contexts,
|
||||||
|
kv_cache_params=KVCacheParams(
|
||||||
|
use_cache=True, num_cached_tokens_per_seq=past_seen_tokens),
|
||||||
|
max_num_requests=len(seq_lens),
|
||||||
|
max_num_sequences=len(seq_lens),
|
||||||
|
max_num_tokens=8192,
|
||||||
|
kv_cache_manager=kv_cache_manager,
|
||||||
|
request_ids=request_ids,
|
||||||
|
prompt_lens=prompt_lens,
|
||||||
|
sparse_attention_config=sparse_attn_config,
|
||||||
|
)
|
||||||
|
metadata.prepare()
|
||||||
|
return metadata
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"batch_size,num_contexts",
|
||||||
|
[
|
||||||
|
(1, 1), # bs=1
|
||||||
|
(4, 4), # bs=2, context only (2 contexts)
|
||||||
|
(6, 3), # bs=6, mixed (3 contexts + 3 generations)
|
||||||
|
])
|
||||||
|
def test_sparse_kv_predict(batch_size, num_contexts):
|
||||||
|
"""
|
||||||
|
Test sparse_kv_predict against vanilla _get_snapkv_indices.
|
||||||
|
|
||||||
|
This test verifies that the batched implementation produces the same results
|
||||||
|
as the single-request implementation for SnapKV sparse attention.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Test configuration
|
||||||
|
num_heads = 32
|
||||||
|
num_kv_heads = 8
|
||||||
|
head_dim = 128
|
||||||
|
device = torch.device('cuda')
|
||||||
|
dtype = torch.bfloat16
|
||||||
|
|
||||||
|
sparse_attn_config = RocketSparseAttentionConfig(
|
||||||
|
window_size=32,
|
||||||
|
kernel_size=3,
|
||||||
|
prompt_budget=256,
|
||||||
|
page_size=3,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create sequence lengths - mix short and long sequences in context phase
|
||||||
|
seq_lens = []
|
||||||
|
past_seen_tokens = []
|
||||||
|
for i in range(batch_size):
|
||||||
|
if i < num_contexts:
|
||||||
|
# Context phase: mix sequences shorter and longer than prompt_budget
|
||||||
|
if i % 2 == 1 and batch_size > 1:
|
||||||
|
# Short sequence: seq_len < prompt_budget
|
||||||
|
seq_lens.append(
|
||||||
|
torch.randint(sparse_attn_config.prompt_budget // 2,
|
||||||
|
sparse_attn_config.prompt_budget - 10,
|
||||||
|
(1, )).item())
|
||||||
|
else:
|
||||||
|
# Long sequence: seq_len > prompt_budget
|
||||||
|
seq_lens.append(
|
||||||
|
torch.randint(sparse_attn_config.prompt_budget,
|
||||||
|
sparse_attn_config.prompt_budget + 200,
|
||||||
|
(1, )).item())
|
||||||
|
past_seen_tokens.append(0)
|
||||||
|
else:
|
||||||
|
# Generation phase: single token
|
||||||
|
seq_lens.append(1)
|
||||||
|
past_seen_tokens.append(torch.randint(100, 200, (1, )).item())
|
||||||
|
|
||||||
|
request_ids = list(range(batch_size))
|
||||||
|
|
||||||
|
num_layers = 1
|
||||||
|
tokens_per_block = 64
|
||||||
|
max_seq_len = 4096
|
||||||
|
if dtype == torch.float16:
|
||||||
|
kv_cache_dtype = tensorrt_llm.bindings.DataType.HALF
|
||||||
|
elif dtype == torch.bfloat16:
|
||||||
|
kv_cache_dtype = tensorrt_llm.bindings.DataType.BF16
|
||||||
|
else:
|
||||||
|
raise ValueError("Invalid dtype")
|
||||||
|
|
||||||
|
vanilla_tokens_per_block = max_seq_len # Each sequence in one block
|
||||||
|
|
||||||
|
trtllm_kv_cache_manager = create_rocket_kv_cache_manager(
|
||||||
|
num_layers, num_kv_heads, head_dim, tokens_per_block, max_seq_len,
|
||||||
|
batch_size, kv_cache_dtype, sparse_attn_config)
|
||||||
|
vanilla_kv_cache_manager = create_rocket_kv_cache_manager(
|
||||||
|
num_layers, num_kv_heads, head_dim, vanilla_tokens_per_block,
|
||||||
|
max_seq_len, batch_size, kv_cache_dtype, sparse_attn_config)
|
||||||
|
|
||||||
|
# Add dummy requests to both cache managers
|
||||||
|
token_nums = [
|
||||||
|
seq_len + past_token
|
||||||
|
for seq_len, past_token in zip(seq_lens, past_seen_tokens)
|
||||||
|
]
|
||||||
|
trtllm_kv_cache_manager.add_dummy_requests(request_ids, token_nums)
|
||||||
|
vanilla_kv_cache_manager.add_dummy_requests(request_ids, token_nums)
|
||||||
|
|
||||||
|
trtllm_attn = RocketTrtllmAttention(
|
||||||
|
layer_idx=0,
|
||||||
|
num_heads=num_heads,
|
||||||
|
head_dim=head_dim,
|
||||||
|
num_kv_heads=num_kv_heads,
|
||||||
|
sparse_attention_config=sparse_attn_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
vanilla_attn = RocketVanillaAttention(
|
||||||
|
layer_idx=0,
|
||||||
|
num_heads=num_heads,
|
||||||
|
head_dim=head_dim,
|
||||||
|
num_kv_heads=num_kv_heads,
|
||||||
|
sparse_attention_config=sparse_attn_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
trtllm_metadata = create_test_metadata(seq_lens, num_contexts,
|
||||||
|
past_seen_tokens, request_ids,
|
||||||
|
trtllm_kv_cache_manager,
|
||||||
|
sparse_attn_config,
|
||||||
|
RocketTrtllmAttentionMetadata)
|
||||||
|
vanilla_metadata = create_test_metadata(seq_lens, num_contexts,
|
||||||
|
past_seen_tokens, request_ids,
|
||||||
|
vanilla_kv_cache_manager,
|
||||||
|
sparse_attn_config,
|
||||||
|
RocketVanillaAttentionMetadata)
|
||||||
|
|
||||||
|
total_tokens = sum(seq_lens)
|
||||||
|
qkv = torch.randn(total_tokens, (num_heads + 2 * num_kv_heads) * head_dim,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device)
|
||||||
|
|
||||||
|
trtllm_sparse_kv_indices, trtllm_sparse_kv_offsets = trtllm_attn.sparse_kv_predict(
|
||||||
|
qkv, None, trtllm_metadata)
|
||||||
|
|
||||||
|
vanilla_sparse_kv_indices_list = []
|
||||||
|
offset = 0
|
||||||
|
for i in range(num_contexts):
|
||||||
|
seq_len = seq_lens[i]
|
||||||
|
single_qkv = qkv[offset:offset + seq_len]
|
||||||
|
q, k, _ = single_qkv.split([
|
||||||
|
num_heads * head_dim, num_kv_heads * head_dim,
|
||||||
|
num_kv_heads * head_dim
|
||||||
|
],
|
||||||
|
dim=-1)
|
||||||
|
q = q.view(1, seq_len, num_heads, head_dim).transpose(1, 2)
|
||||||
|
k = k.view(1, seq_len, num_kv_heads, head_dim)
|
||||||
|
|
||||||
|
if seq_len <= sparse_attn_config.prompt_budget:
|
||||||
|
# Short sequences: vanilla returns None, but trtllm returns [0, 1, ..., seq_len-1]
|
||||||
|
# Generate expected indices for comparison
|
||||||
|
short_indices = torch.arange(seq_len,
|
||||||
|
device=device,
|
||||||
|
dtype=torch.int32).unsqueeze(0).expand(
|
||||||
|
num_kv_heads, -1)
|
||||||
|
vanilla_sparse_kv_indices_list.append(short_indices)
|
||||||
|
else:
|
||||||
|
vanilla_indices = vanilla_attn._get_snapkv_indices(q, k, i)
|
||||||
|
if vanilla_indices is not None:
|
||||||
|
vanilla_indices = vanilla_indices.squeeze(0).transpose(
|
||||||
|
0, 1).contiguous()
|
||||||
|
vanilla_sparse_kv_indices_list.append(vanilla_indices)
|
||||||
|
|
||||||
|
offset += seq_len
|
||||||
|
|
||||||
|
if len(vanilla_sparse_kv_indices_list) > 0:
|
||||||
|
vanilla_sparse_kv_indices = torch.cat(vanilla_sparse_kv_indices_list,
|
||||||
|
dim=-1).contiguous()
|
||||||
|
else:
|
||||||
|
vanilla_sparse_kv_indices = None
|
||||||
|
|
||||||
|
# Compare results
|
||||||
|
if trtllm_sparse_kv_indices is not None:
|
||||||
|
assert vanilla_sparse_kv_indices is not None, "Vanilla should also produce indices"
|
||||||
|
|
||||||
|
assert trtllm_sparse_kv_indices.shape == vanilla_sparse_kv_indices.shape, \
|
||||||
|
f"Shape mismatch: {trtllm_sparse_kv_indices.shape} vs {vanilla_sparse_kv_indices.shape}"
|
||||||
|
|
||||||
|
# Check indices overlap per batch and per head
|
||||||
|
num_kv_heads = trtllm_sparse_kv_indices.shape[0]
|
||||||
|
|
||||||
|
# trtllm_sparse_kv_offsets tells where each batch's indices start/end
|
||||||
|
trtllm_offsets = trtllm_sparse_kv_offsets.cpu().tolist()
|
||||||
|
|
||||||
|
overlap_ratios = []
|
||||||
|
batch_overlap_details = []
|
||||||
|
|
||||||
|
for batch_idx in range(num_contexts):
|
||||||
|
start_idx = trtllm_offsets[batch_idx]
|
||||||
|
end_idx = trtllm_offsets[batch_idx + 1]
|
||||||
|
end_idx - start_idx
|
||||||
|
|
||||||
|
batch_overlaps = []
|
||||||
|
for head_idx in range(num_kv_heads):
|
||||||
|
trtllm_batch = trtllm_sparse_kv_indices[
|
||||||
|
head_idx, start_idx:end_idx].cpu().tolist()
|
||||||
|
vanilla_batch = vanilla_sparse_kv_indices[
|
||||||
|
head_idx, start_idx:end_idx].cpu().tolist()
|
||||||
|
|
||||||
|
trtllm_set = set(trtllm_batch)
|
||||||
|
vanilla_set = set(vanilla_batch)
|
||||||
|
|
||||||
|
# Calculate overlap
|
||||||
|
overlap = len(vanilla_set & trtllm_set)
|
||||||
|
overlap_ratio = overlap / len(vanilla_set) if len(
|
||||||
|
vanilla_set) > 0 else 1.0
|
||||||
|
batch_overlaps.append(overlap_ratio)
|
||||||
|
overlap_ratios.append(overlap_ratio)
|
||||||
|
|
||||||
|
avg_batch_overlap = sum(batch_overlaps) / len(batch_overlaps)
|
||||||
|
batch_overlap_details.append(
|
||||||
|
f"Batch {batch_idx}: {avg_batch_overlap:.4f}")
|
||||||
|
|
||||||
|
avg_overlap_ratio = sum(overlap_ratios) / len(overlap_ratios)
|
||||||
|
print(f"Average overlap ratio: {avg_overlap_ratio:.4f}")
|
||||||
|
print(f"Per-batch average: {batch_overlap_details}")
|
||||||
|
|
||||||
|
assert avg_overlap_ratio >= 0.98, \
|
||||||
|
f"Indices overlap ratio {avg_overlap_ratio:.4f} is too low (< 0.98)"
|
||||||
|
else:
|
||||||
|
assert vanilla_sparse_kv_indices is None, "Both should return None when no sparse attention is needed"
|
||||||
|
|
||||||
|
|
||||||
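Illustrative note: both unit tests score agreement with the same set-overlap metric, computed per head and per request; in isolation, with hypothetical index sets:

    vanilla_set = {4, 7, 9, 15}                        # reference (vanilla) indices
    trtllm_set = {4, 7, 15, 21}                        # batched Triton-path indices
    overlap_ratio = len(vanilla_set & trtllm_set) / len(vanilla_set)   # -> 0.75
    # the tests require the average ratio to stay above 0.98 (sparse_kv) / 0.94 (sparse_attn)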
|
@pytest.mark.parametrize(
|
||||||
|
"batch_size,num_contexts",
|
||||||
|
[
|
||||||
|
(1, 0), # bs=1, generation only (1 generation)
|
||||||
|
(2, 0), # bs=2, generation only (2 generations)
|
||||||
|
(3, 0), # bs=3, generation only (3 generations)
|
||||||
|
(5, 3), # bs=5, mixed (3 contexts + 2 generations)
|
||||||
|
(6, 2), # bs=6, mixed (2 ctx + 4 gen)
|
||||||
|
])
|
||||||
|
def test_sparse_attn_predict(batch_size, num_contexts):
|
||||||
|
"""Test sparse_attn_predict against vanilla _rocketkv_selection."""
|
||||||
|
num_generations = batch_size - num_contexts
|
||||||
|
|
||||||
|
# Test configuration
|
||||||
|
num_heads = 32
|
||||||
|
num_kv_heads = 8
|
||||||
|
head_dim = 128
|
||||||
|
device = torch.device('cuda')
|
||||||
|
dtype = torch.bfloat16
|
||||||
|
|
||||||
|
sparse_attn_config_vanilla = RocketSparseAttentionConfig(
|
||||||
|
window_size=32,
|
||||||
|
kernel_size=3,
|
||||||
|
prompt_budget=256,
|
||||||
|
page_size=3,
|
||||||
|
topk=128,
|
||||||
|
topr=96,
|
||||||
|
)
|
||||||
|
sparse_attn_config_trtllm = RocketSparseAttentionConfig(
|
||||||
|
window_size=32,
|
||||||
|
kernel_size=3,
|
||||||
|
prompt_budget=256,
|
||||||
|
page_size=3,
|
||||||
|
topk=43,
|
||||||
|
topr=96,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create sequence lengths
|
||||||
|
seq_lens = []
|
||||||
|
past_seen_tokens = []
|
||||||
|
for i in range(batch_size):
|
||||||
|
if i < num_contexts:
|
||||||
|
# Context phase: longer sequences
|
||||||
|
seq_lens.append(torch.randint(300, 400, (1, )).item())
|
||||||
|
past_seen_tokens.append(0)
|
||||||
|
else:
|
||||||
|
# Generation phase: single token
|
||||||
|
seq_lens.append(1)
|
||||||
|
# 128 is the minimum number of tokens for shape alignment
|
||||||
|
past_seen_tokens.append(torch.randint(128, 300, (1, )).item())
|
||||||
|
|
||||||
|
request_ids = list(range(batch_size))
|
||||||
|
|
||||||
|
num_layers = 1
|
||||||
|
tokens_per_block = 64
|
||||||
|
max_seq_len = 4096
|
||||||
|
if dtype == torch.float16:
|
||||||
|
kv_cache_dtype = tensorrt_llm.bindings.DataType.HALF
|
||||||
|
elif dtype == torch.bfloat16:
|
||||||
|
kv_cache_dtype = tensorrt_llm.bindings.DataType.BF16
|
||||||
|
else:
|
||||||
|
raise ValueError("Invalid dtype")
|
||||||
|
|
||||||
|
vanilla_tokens_per_block = max_seq_len # Each sequence in one block
|
||||||
|
|
||||||
|
trtllm_kv_cache_manager = create_rocket_kv_cache_manager(
|
||||||
|
num_layers, num_kv_heads, head_dim, tokens_per_block, max_seq_len,
|
||||||
|
batch_size, kv_cache_dtype, sparse_attn_config_trtllm)
|
||||||
|
vanilla_kv_cache_manager = create_rocket_kv_cache_manager(
|
||||||
|
num_layers, num_kv_heads, head_dim, vanilla_tokens_per_block,
|
||||||
|
max_seq_len, batch_size, kv_cache_dtype, sparse_attn_config_vanilla)
|
||||||
|
|
||||||
|
# Add dummy requests to both cache managers
|
||||||
|
token_nums = [
|
||||||
|
seq_len + past_token
|
||||||
|
for seq_len, past_token in zip(seq_lens, past_seen_tokens)
|
||||||
|
]
|
||||||
|
trtllm_kv_cache_manager.add_dummy_requests(request_ids, token_nums)
|
||||||
|
vanilla_kv_cache_manager.add_dummy_requests(request_ids, token_nums)
|
||||||
|
|
||||||
|
trtllm_attn = RocketTrtllmAttention(
|
||||||
|
layer_idx=0,
|
||||||
|
num_heads=num_heads,
|
||||||
|
head_dim=head_dim,
|
||||||
|
num_kv_heads=num_kv_heads,
|
||||||
|
sparse_attention_config=sparse_attn_config_trtllm,
|
||||||
|
)
|
||||||
|
|
||||||
|
vanilla_attn = RocketVanillaAttention(
|
||||||
|
layer_idx=0,
|
||||||
|
num_heads=num_heads,
|
||||||
|
head_dim=head_dim,
|
||||||
|
num_kv_heads=num_kv_heads,
|
||||||
|
sparse_attention_config=sparse_attn_config_vanilla,
|
||||||
|
)
|
||||||
|
|
||||||
|
trtllm_metadata = create_test_metadata(seq_lens, num_contexts,
|
||||||
|
past_seen_tokens, request_ids,
|
||||||
|
trtllm_kv_cache_manager,
|
||||||
|
sparse_attn_config_trtllm,
|
||||||
|
RocketTrtllmAttentionMetadata)
|
||||||
|
vanilla_metadata = create_test_metadata(seq_lens, num_contexts,
|
||||||
|
past_seen_tokens, request_ids,
|
||||||
|
vanilla_kv_cache_manager,
|
||||||
|
sparse_attn_config_vanilla,
|
||||||
|
RocketVanillaAttentionMetadata)
|
||||||
|
|
||||||
|
total_tokens = sum(seq_lens)
|
||||||
|
qkv = torch.randn(total_tokens, (num_heads + 2 * num_kv_heads) * head_dim,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device)
|
||||||
|
|
||||||
|
for layer_idx in range(num_layers):
|
||||||
|
trtllm_kt_buf = trtllm_kv_cache_manager.get_kt_buffers(layer_idx)
|
||||||
|
vanilla_kt_buf = vanilla_kv_cache_manager.get_kt_buffers(layer_idx)
|
||||||
|
|
||||||
|
torch.nn.init.normal_(trtllm_kt_buf)
|
||||||
|
|
||||||
|
# Map trtllm data to vanilla based on block offsets
|
||||||
|
# TRTLLM: (num_blocks, kt_tokens_per_block, num_kv_heads, 2*head_dim)
|
||||||
|
# VANILLA: (num_blocks, kt_tokens_per_block_vanilla, num_kv_heads, 2*head_dim)
|
||||||
|
trtllm_kt_tokens_per_block = trtllm_kv_cache_manager.kt_tokens_per_block
|
||||||
|
vanilla_kv_cache_manager.kt_tokens_per_block
|
||||||
|
|
||||||
|
trtllm_block_offsets = trtllm_metadata.kt_cache_block_offsets
|
||||||
|
vanilla_block_offsets = vanilla_metadata.kt_cache_block_offsets
|
||||||
|
|
||||||
|
for req_idx in range(num_contexts, batch_size):
|
||||||
|
# Get the number of KT tokens for this request
|
||||||
|
past_token = past_seen_tokens[req_idx]
|
||||||
|
num_kt_tokens = (past_token + 1 +
|
||||||
|
sparse_attn_config_trtllm.page_size -
|
||||||
|
1) // sparse_attn_config_trtllm.page_size
|
||||||
|
|
||||||
|
# Get block offsets for this request
|
||||||
|
trtllm_blocks = trtllm_block_offsets[req_idx]
|
||||||
|
vanilla_blocks = vanilla_block_offsets[req_idx]
|
||||||
|
|
||||||
|
# Copy data from trtllm blocks to vanilla blocks
|
||||||
|
kt_token_idx = 0
|
||||||
|
vanilla_block_idx = 0
|
||||||
|
|
||||||
|
# For trtllm: iterate through blocks and copy KT tokens
|
||||||
|
for trtllm_block_local_idx in range(len(trtllm_blocks)):
|
||||||
|
if kt_token_idx >= num_kt_tokens:
|
||||||
|
break
|
||||||
|
|
||||||
|
trtllm_block = trtllm_blocks[trtllm_block_local_idx]
|
||||||
|
if trtllm_block < 0:
|
||||||
|
break
|
||||||
|
|
||||||
|
# How many KT tokens in this trtllm block
|
||||||
|
kt_tokens_in_this_block = min(trtllm_kt_tokens_per_block,
|
||||||
|
num_kt_tokens - kt_token_idx)
|
||||||
|
|
||||||
|
# Copy to vanilla buffer
|
||||||
|
vanilla_block = vanilla_blocks[vanilla_block_idx]
|
||||||
|
if vanilla_block >= 0:
|
||||||
|
vanilla_kt_buf[vanilla_block, kt_token_idx:kt_token_idx +
|
||||||
|
kt_tokens_in_this_block].copy_(trtllm_kt_buf[
|
||||||
|
trtllm_block, :kt_tokens_in_this_block])
|
||||||
|
|
||||||
|
kt_token_idx += kt_tokens_in_this_block
|
||||||
|
|
||||||
|
trtllm_sparse_attn_indices, trtllm_sparse_attn_offsets = trtllm_attn.sparse_attn_predict(
|
||||||
|
qkv, None, trtllm_metadata)
|
||||||
|
|
||||||
|
vanilla_sparse_attn_indices_list = []
|
||||||
|
offset = sum(seq_lens[:num_contexts]) # Skip context tokens
|
||||||
|
|
||||||
|
for i in range(num_contexts, batch_size):
|
||||||
|
seq_len = seq_lens[i]
|
||||||
|
single_qkv = qkv[offset:offset + seq_len]
|
||||||
|
q, k, _ = single_qkv.split([
|
||||||
|
num_heads * head_dim, num_kv_heads * head_dim,
|
||||||
|
num_kv_heads * head_dim
|
||||||
|
],
|
||||||
|
dim=-1)
|
||||||
|
q = q.view(1, seq_len, num_heads, head_dim).transpose(1, 2)
|
||||||
|
k = k.view(1, seq_len, num_kv_heads, head_dim)
|
||||||
|
|
||||||
|
past_seen_token = past_seen_tokens[i]
|
||||||
|
vanilla_indices = vanilla_attn._rocketkv_selection(
|
||||||
|
q, k, vanilla_metadata, past_seen_token, i)
|
||||||
|
vanilla_sparse_attn_indices_list.append(vanilla_indices.squeeze(0))
|
||||||
|
offset += seq_len
|
||||||
|
|
||||||
|
if trtllm_sparse_attn_indices is not None:
|
||||||
|
assert len(vanilla_sparse_attn_indices_list
|
||||||
|
) > 0, "Vanilla should also produce indices"
|
||||||
|
|
||||||
|
vanilla_sparse_attn_indices = torch.cat(
|
||||||
|
vanilla_sparse_attn_indices_list, dim=0).transpose(0,
|
||||||
|
1).contiguous()
|
||||||
|
|
||||||
|
# Apply interleave operation to trtllm indices
|
||||||
|
# For each head, multiply indices by page_size and expand to include all tokens in each page
|
||||||
|
page_size = sparse_attn_config_trtllm.page_size
|
||||||
|
num_kv_heads, total_indices = trtllm_sparse_attn_indices.shape
|
||||||
|
|
||||||
|
interleaved_indices_list = []
|
||||||
|
for head_idx in range(num_kv_heads):
|
||||||
|
head_indices = trtllm_sparse_attn_indices[
|
||||||
|
head_idx] # Shape: [total_indices]
|
||||||
|
|
||||||
|
page_starts = head_indices * page_size # Shape: [total_indices]
|
||||||
|
|
||||||
|
expanded_indices = []
|
||||||
|
for page_start in page_starts:
|
||||||
|
page_indices = torch.arange(page_start,
|
||||||
|
page_start + page_size,
|
||||||
|
device=page_starts.device)
|
||||||
|
expanded_indices.append(page_indices)
|
||||||
|
|
||||||
|
head_interleaved = torch.cat(expanded_indices, dim=0)
|
||||||
|
|
||||||
|
# Slice to match vanilla shape
|
||||||
|
target_length = vanilla_sparse_attn_indices.shape[1]
|
||||||
|
head_interleaved = head_interleaved[:target_length]
|
||||||
|
|
||||||
|
interleaved_indices_list.append(head_interleaved)
|
||||||
|
|
||||||
|
# Stack all heads
|
||||||
|
trtllm_sparse_attn_indices = torch.stack(interleaved_indices_list,
|
||||||
|
dim=0)
|
||||||
|
|
||||||
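Illustrative note: the page-to-token expansion above, on hypothetical values (page_size = 3, matching the test config):

    import torch
    pages, page_size = [0, 2], 3
    tokens = torch.cat(
        [torch.arange(p * page_size, (p + 1) * page_size) for p in pages])
    # tokens -> tensor([0, 1, 2, 6, 7, 8])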
|
assert trtllm_sparse_attn_indices.shape == vanilla_sparse_attn_indices.shape, \
|
||||||
|
f"Shape mismatch: {trtllm_sparse_attn_indices.shape} vs {vanilla_sparse_attn_indices.shape}"
|
||||||
|
|
||||||
|
trtllm_sparse_attn_indices = trtllm_sparse_attn_indices.sort(
|
||||||
|
dim=-1).values
|
||||||
|
vanilla_sparse_attn_indices = vanilla_sparse_attn_indices.sort(
|
||||||
|
dim=-1).values
|
||||||
|
|
||||||
|
# Check indices overlap per batch and per head
|
||||||
|
num_kv_heads = trtllm_sparse_attn_indices.shape[0]
|
||||||
|
|
||||||
|
trtllm_offsets = trtllm_sparse_attn_offsets.cpu().tolist()
|
||||||
|
|
||||||
|
overlap_ratios = []
|
||||||
|
batch_overlap_details = []
|
||||||
|
|
||||||
|
num_generations = batch_size - num_contexts
|
||||||
|
for batch_idx in range(num_generations):
|
||||||
|
start_idx = trtllm_offsets[batch_idx]
|
||||||
|
end_idx = trtllm_offsets[batch_idx + 1]
|
||||||
|
end_idx - start_idx
|
||||||
|
|
||||||
|
batch_overlaps = []
|
||||||
|
for head_idx in range(num_kv_heads):
|
||||||
|
trtllm_batch = trtllm_sparse_attn_indices[
|
||||||
|
head_idx, start_idx:end_idx].cpu().tolist()
|
||||||
|
vanilla_batch = vanilla_sparse_attn_indices[
|
||||||
|
head_idx, start_idx:end_idx].cpu().tolist()
|
||||||
|
|
||||||
|
trtllm_set = set(trtllm_batch)
|
||||||
|
vanilla_set = set(vanilla_batch)
|
||||||
|
|
||||||
|
# Calculate overlap
|
||||||
|
overlap = len(vanilla_set & trtllm_set)
|
||||||
|
overlap_ratio = overlap / len(vanilla_set) if len(
|
||||||
|
vanilla_set) > 0 else 1.0
|
||||||
|
batch_overlaps.append(overlap_ratio)
|
||||||
|
overlap_ratios.append(overlap_ratio)
|
||||||
|
|
||||||
|
avg_batch_overlap = sum(batch_overlaps) / len(batch_overlaps)
|
||||||
|
batch_overlap_details.append(
|
||||||
|
f"Batch {batch_idx}: {avg_batch_overlap:.4f}")
|
||||||
|
|
||||||
|
avg_overlap_ratio = sum(overlap_ratios) / len(overlap_ratios)
|
||||||
|
print(f"Average overlap ratio: {avg_overlap_ratio:.4f}")
|
||||||
|
print(f"Per-batch average: {batch_overlap_details}")
|
||||||
|
|
||||||
|
threshold = 0.94
|
||||||
|
assert avg_overlap_ratio >= threshold, \
|
||||||
|
f"Indices overlap ratio {avg_overlap_ratio:.4f} is too low (< {threshold})"
|
||||||
|
else:
|
||||||
|
assert len(
|
||||||
|
vanilla_sparse_attn_indices_list
|
||||||
|
) == 0, "Both should return None when no sparse attention is needed"
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
# RocketKV e2e tests
|
||||||
|
print("=== Testing RocketKV E2E tests ===")
|
||||||
test_model("pytorch", "llama-3.1-model/Llama-3.1-8B-Instruct", "VANILLA")
|
test_model("pytorch", "llama-3.1-model/Llama-3.1-8B-Instruct", "VANILLA")
|
||||||
test_model("pytorch", "llama-3.1-model/Llama-3.1-8B-Instruct", "TRTLLM")
|
test_model("pytorch", "llama-3.1-model/Llama-3.1-8B-Instruct", "TRTLLM")
|
||||||
|
|
||||||
|
# Unit tests for sparse_kv_predict
|
||||||
|
print("\n=== Testing sparse_kv_predict ===")
|
||||||
|
test_sparse_kv_predict(1, 1) # bs=1, context only
|
||||||
|
test_sparse_kv_predict(2, 2) # bs=2, context only
|
||||||
|
test_sparse_kv_predict(6, 3) # bs=6, mixed (3 ctx + 3 gen)
|
||||||
|
|
||||||
|
# Unit tests for sparse_attn_predict
|
||||||
|
print("\n=== Testing sparse_attn_predict ===")
|
||||||
|
test_sparse_attn_predict(1, 0) # bs=1, generation only
|
||||||
|
test_sparse_attn_predict(2, 0) # bs=2, generation only
|
||||||
|
test_sparse_attn_predict(3, 0) # bs=3, generation only
|
||||||
|
test_sparse_attn_predict(5, 3) # bs=5, mixed (3 ctx + 2 gen)
|
||||||
|
test_sparse_attn_predict(6, 2) # bs=6, mixed (2 ctx + 4 gen)
|
||||||
|
|||||||
413  tests/unittest/_torch/attention/sparse/test_triton_bmm.py  (new normal file)
@ -0,0 +1,413 @@
|
|||||||
|
import math

import pytest
import torch

from tensorrt_llm._torch.attention_backend.sparse.kernel import (
    triton_bmm,
    triton_rocket_paged_kt_cache_bmm,
)


def pytorch_reference_bmm(
    q: torch.Tensor,
    k: torch.Tensor,
    q_cu_seqlens: torch.Tensor,
    k_cu_seqlens: torch.Tensor,
    batch_size: int,
    sm_scale: float = None,
    causal: bool = False,
) -> torch.Tensor:
    num_q_heads, total_q_tokens, head_dim = q.shape
    num_k_heads, total_k_tokens, _ = k.shape

    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(head_dim)

    # Compute q_len_per_seq
    q_len_per_seq = total_q_tokens // batch_size

    scores = torch.full(
        (num_q_heads, q_len_per_seq, total_k_tokens),
        float("-inf"),
        dtype=torch.float32,
        device=q.device,
    )

    # Process each batch
    for batch_idx in range(batch_size):
        q_start = q_cu_seqlens[batch_idx].item()
        q_end = q_cu_seqlens[batch_idx + 1].item()
        k_start = k_cu_seqlens[batch_idx].item()
        k_end = k_cu_seqlens[batch_idx + 1].item()

        q_seqlen = q_end - q_start
        k_seqlen = k_end - k_start

        if q_seqlen <= 0 or k_seqlen <= 0:
            continue

        q_batch = q[:, q_start:q_end, :]  # [num_q_heads, q_seqlen, head_dim]

        num_heads_per_kv = num_q_heads // num_k_heads

        for head_idx in range(num_q_heads):
            k_head_idx = head_idx // num_heads_per_kv
            k_batch = k[k_head_idx, k_start:k_end, :]  # [k_seqlen, head_dim]

            qk = torch.matmul(q_batch[head_idx], k_batch.T) * sm_scale

            if causal:
                causal_mask = torch.triu(
                    torch.ones(q_seqlen, k_seqlen, device=q.device, dtype=torch.bool), diagonal=1
                )
                qk = qk.masked_fill(causal_mask, float("-inf"))

            scores[head_idx, :q_seqlen, k_start:k_end] = qk

    return scores

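# Worked example of the GQA mapping in pytorch_reference_bmm above (an
# illustration, not part of the kernel contract): with num_q_heads = 32 and
# num_k_heads = 8, num_heads_per_kv = 4, so query heads 0-3 all attend against
# KV head 0, heads 4-7 against KV head 1, and so on.
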
def create_kt_cache_from_k(
    k: torch.Tensor,
    kv_lens: torch.Tensor,
    kt_page_size: int,
    tokens_per_block: int,
    num_kv_heads: int,
    head_dim: int,
) -> tuple[torch.Tensor, torch.Tensor, int]:
    """
    Create paged KT cache tensor from continuous K tensor.

    Args:
        k: Key tensor [num_gen_tokens, num_kv_heads * head_dim]
        kv_lens: Sequence lengths [num_gen_tokens]
        kt_page_size: Page size for KT tokens
        tokens_per_block: Tokens per cache block
        num_kv_heads: Number of KV heads
        head_dim: Head dimension

    Returns:
        kt_cache_tensor: KT cache [num_blocks, tokens_per_block, num_kv_heads, 2*head_dim]
        kt_cache_block_offsets: Block offsets [num_gen_tokens, max_kt_blocks_per_seq]
        max_kt_blocks_per_seq: Maximum KT blocks per sequence
    """
    num_gen_tokens = k.shape[0]
    device = k.device
    dtype = k.dtype

    # Calculate number of kt tokens per sequence
    num_kt_tokens_per_seq = [
        (kv_len.item() + kt_page_size - 1) // kt_page_size for kv_len in kv_lens
    ]
    max_kt_tokens = max(num_kt_tokens_per_seq)
    max_kt_blocks_per_seq = (max_kt_tokens + tokens_per_block - 1) // tokens_per_block

    # Calculate total number of blocks needed
    total_blocks_needed = sum(
        (kt_tokens + tokens_per_block - 1) // tokens_per_block
        for kt_tokens in num_kt_tokens_per_seq
    )

    # Create KT cache tensor
    kt_cache_tensor = torch.zeros(
        (total_blocks_needed, tokens_per_block, num_kv_heads, 2 * head_dim),
        device=device,
        dtype=dtype,
    )

    # Create block offsets tensor
    kt_cache_block_offsets = torch.full(
        (num_gen_tokens, max_kt_blocks_per_seq), -1, dtype=torch.int32, device=device
    )

    # Fill KT cache and block offsets
    current_block_idx = 0

    for seq_idx in range(num_gen_tokens):
        kv_len = kv_lens[seq_idx].item()
        num_kt_tokens = num_kt_tokens_per_seq[seq_idx]

        # Reshape k for this sequence: [num_kv_heads, head_dim]
        k_seq = k[seq_idx].view(num_kv_heads, head_dim)

        # Process each kt token (page)
        for kt_idx in range(num_kt_tokens):
            page_start = kt_idx * kt_page_size

            # For simplicity, we use the first token in the page as representative
            # In real usage, this would be min/max over the page
            # Here we just replicate the first token's value for testing
            token_idx = page_start
            if token_idx < kv_len:
                k_val = k_seq  # [num_kv_heads, head_dim]

                # Store k_min and k_max (for testing, we use same value)
                kt_min = k_val
                kt_max = k_val

                # Determine which block this kt token belongs to
                block_offset = kt_idx // tokens_per_block
                token_offset_in_block = kt_idx % tokens_per_block

                # Assign block index if not already assigned
                if kt_cache_block_offsets[seq_idx, block_offset] < 0:
                    kt_cache_block_offsets[seq_idx, block_offset] = current_block_idx
                    current_block_idx += 1

                block_idx = kt_cache_block_offsets[seq_idx, block_offset].item()

                # Store in cache: [block, token_in_block, head, 2*head_dim]
                kt_cache_tensor[block_idx, token_offset_in_block, :, :head_dim] = kt_min
                kt_cache_tensor[block_idx, token_offset_in_block, :, head_dim:] = kt_max

    return kt_cache_tensor, kt_cache_block_offsets, max_kt_blocks_per_seq

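# Worked example of the paging arithmetic in create_kt_cache_from_k above: with
# kt_page_size = 3 and tokens_per_block = 64, a sequence with kv_len = 200 has
# ceil(200 / 3) = 67 kt tokens; kt token 66 lands in block_offset 66 // 64 = 1
# at slot 66 % 64 = 2, so that sequence consumes two cache blocks.
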
def pytorch_reference_paged_kt_cache_bmm(
    q: torch.Tensor,
    k: torch.Tensor,
    dim_pos: torch.Tensor,
    kv_lens: torch.Tensor,
    kt_page_size: int,
    sm_scale: float = None,
) -> torch.Tensor:
    num_gen_tokens, num_kv_heads, num_heads_per_kv, head_dim = q.shape
    total_num_heads = num_kv_heads * num_heads_per_kv
    device = q.device

    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(head_dim)

    max_kt_tokens = max((kv_len.item() + kt_page_size - 1) // kt_page_size for kv_len in kv_lens)
    total_kt_tokens = num_gen_tokens * max_kt_tokens

    scores = torch.zeros((total_num_heads, 1, total_kt_tokens), dtype=torch.float32, device=device)

    # Process each generation token
    for batch_idx in range(num_gen_tokens):
        kv_len = kv_lens[batch_idx].item()
        num_kt_tokens = (kv_len + kt_page_size - 1) // kt_page_size

        q_batch = q[batch_idx]  # [num_kv_heads, num_heads_per_kv, head_dim]
        k_batch = k[batch_idx].view(num_kv_heads, head_dim)  # [num_kv_heads, head_dim]
        dim_pos_batch = dim_pos[batch_idx]  # [num_kv_heads, head_dim]

        output_offset = batch_idx * max_kt_tokens

        for kv_head_idx in range(num_kv_heads):
            for q_head_idx in range(num_heads_per_kv):
                global_head_idx = kv_head_idx * num_heads_per_kv + q_head_idx

                q_vec = q_batch[kv_head_idx, q_head_idx]  # [head_dim]
                k_vec = k_batch[kv_head_idx]  # [head_dim]
                dim_pos_vec = dim_pos_batch[kv_head_idx]  # [head_dim]

                # Simulate KT selection based on dim_pos
                k_selected = torch.where(dim_pos_vec > 0, k_vec, k_vec)

                # Compute score for each kt token (simplified)
                for kt_idx in range(num_kt_tokens):
                    score = torch.dot(q_vec, k_selected) * sm_scale
                    scores[global_head_idx, 0, output_offset + kt_idx] = score

    return scores

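# Note on the reference above: because create_kt_cache_from_k replicates one
# representative key per sequence (kt_min == kt_max), and
# torch.where(dim_pos_vec > 0, k_vec, k_vec) returns k_vec either way, the
# expected score is the same q . k * sm_scale value for every kt token of a
# sequence; the test compares the Triton kernel against this
# per-sequence-constant baseline rather than true per-page min/max scores.
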
@pytest.mark.parametrize(
    "batch_size,q_len_per_seq,k_lens,num_q_heads,num_kv_heads,head_dim,causal",
    [
        # Single batch
        (1, 32, [128], 8, 8, 128, False),
        (1, 32, [128], 8, 8, 128, True),
        # Multiple batches with different k_len
        (3, 32, [64, 128, 256], 8, 8, 128, False),
        (4, 16, [100, 200, 150, 300], 32, 8, 128, False),
        # Edge cases
        (2, 1, [10, 20], 8, 8, 128, False),  # q_len=1
        (2, 64, [64, 128], 16, 4, 64, True),  # Different head_dim with causal
    ],
)
def test_triton_bmm(batch_size, q_len_per_seq, k_lens, num_q_heads, num_kv_heads, head_dim, causal):
    device = torch.device("cuda")
    dtype = torch.float32

    total_q_tokens = batch_size * q_len_per_seq
    q_cu_seqlens = torch.arange(0, batch_size + 1, device=device, dtype=torch.int32) * q_len_per_seq

    k_lens_tensor = torch.tensor(k_lens, dtype=torch.int32, device=device)
    k_cu_seqlens = torch.cat(
        [torch.zeros(1, dtype=torch.int32, device=device), torch.cumsum(k_lens_tensor, dim=0)]
    )
    total_k_tokens = k_cu_seqlens[-1].item()

    q = torch.randn((num_q_heads, total_q_tokens, head_dim), dtype=dtype, device=device)
    k = torch.randn((num_kv_heads, total_k_tokens, head_dim), dtype=dtype, device=device)

    triton_scores = triton_bmm(
        q=q,
        k=k,
        q_cu_seqlens=q_cu_seqlens,
        k_cu_seqlens=k_cu_seqlens,
        batch_size=batch_size,
        sm_scale=None,
        causal=causal,
    )

    reference_scores = pytorch_reference_bmm(
        q=q,
        k=k,
        q_cu_seqlens=q_cu_seqlens,
        k_cu_seqlens=k_cu_seqlens,
        batch_size=batch_size,
        sm_scale=None,
        causal=causal,
    )

    # Compare results
    # Handle -inf values separately
    triton_finite = torch.isfinite(triton_scores)
    reference_finite = torch.isfinite(reference_scores)

    # Check that inf/finite masks match
    assert torch.all(triton_finite == reference_finite), (
        "Finite/infinite mask mismatch between Triton and reference"
    )

    # Compare finite values
    if triton_finite.any():
        max_diff = torch.max(
            torch.abs(triton_scores[triton_finite] - reference_scores[reference_finite])
        ).item()

        print(f"Max absolute difference: {max_diff:.6f}")

        assert max_diff < 0.01, f"Max difference {max_diff} exceeds threshold"

@pytest.mark.parametrize(
    "batch_size,kv_lens,num_kv_heads,num_heads_per_kv,head_dim,kt_page_size,tokens_per_block",
    [
        # Single batch
        (1, [128], 8, 4, 128, 3, 64),
        # Multiple batches with different kv_len
        (3, [100, 200, 150], 8, 4, 128, 3, 64),
    ],
)
def test_triton_rocket_paged_kt_cache_bmm(
    batch_size, kv_lens, num_kv_heads, num_heads_per_kv, head_dim, kt_page_size, tokens_per_block
):
    device = torch.device("cuda")
    dtype = torch.bfloat16

    num_gen_tokens = batch_size

    kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32, device=device)

    # Create Q tensor: [num_gen_tokens, num_kv_heads, num_heads_per_kv, head_dim]
    q = torch.randn(
        (num_gen_tokens, num_kv_heads, num_heads_per_kv, head_dim), dtype=dtype, device=device
    )

    # Create K tensor for reference: [num_gen_tokens, num_kv_heads * head_dim]
    k = torch.randn((num_gen_tokens, num_kv_heads * head_dim), dtype=dtype, device=device)

    # Create dim_pos: [num_gen_tokens, num_kv_heads, head_dim]
    # Randomly set some dimensions to head_dim (positive) and others to 0
    dim_pos = torch.zeros(
        (num_gen_tokens, num_kv_heads, head_dim), dtype=torch.int32, device=device
    )
    for i in range(num_gen_tokens):
        for j in range(num_kv_heads):
            num_positive = torch.randint(0, head_dim, (1,)).item()
            positive_indices = torch.randperm(head_dim)[:num_positive]
            dim_pos[i, j, positive_indices] = head_dim

    # Create paged KT cache
    kt_cache_tensor, kt_cache_block_offsets, max_kt_blocks_per_seq = create_kt_cache_from_k(
        k=k,
        kv_lens=kv_lens_tensor,
        kt_page_size=kt_page_size,
        tokens_per_block=tokens_per_block,
        num_kv_heads=num_kv_heads,
        head_dim=head_dim,
    )

    # Calculate output offsets
    max_kt_tokens = max((kv_len + kt_page_size - 1) // kt_page_size for kv_len in kv_lens)
    total_kt_tokens = num_gen_tokens * max_kt_tokens
    output_offsets = (
        torch.arange(0, num_gen_tokens + 1, device=device, dtype=torch.int32) * max_kt_tokens
    )

    triton_scores = triton_rocket_paged_kt_cache_bmm(
        q=q,
        kt_cache_tensor=kt_cache_tensor,
        kt_cache_block_offsets=kt_cache_block_offsets,
        dim_pos=dim_pos,
        kv_lens=kv_lens_tensor,
        output_offsets=output_offsets,
        kt_page_size=kt_page_size,
        tokens_per_block=tokens_per_block,
        max_kt_blocks_per_seq=max_kt_blocks_per_seq,
        total_kt_tokens=total_kt_tokens,
        sm_scale=None,
    )

    reference_scores = pytorch_reference_paged_kt_cache_bmm(
        q=q,
        k=k,
        dim_pos=dim_pos,
        kv_lens=kv_lens_tensor,
        kt_page_size=kt_page_size,
        sm_scale=None,
    )

    # Compare results
    # Only compare non-zero entries (valid kt tokens)
    mask = torch.zeros_like(triton_scores, dtype=torch.bool)
    for batch_idx in range(num_gen_tokens):
        kv_len = kv_lens[batch_idx]
        num_kt_tokens = (kv_len + kt_page_size - 1) // kt_page_size
        offset = batch_idx * max_kt_tokens
        mask[:, :, offset : offset + num_kt_tokens] = True

    triton_valid = triton_scores[mask]
    reference_valid = reference_scores[mask]

    max_diff = torch.max(torch.abs(triton_valid - reference_valid)).item()

    print(f"Max absolute difference: {max_diff:.6f}")

    assert max_diff < 0.05, f"Max difference {max_diff} exceeds threshold"


if __name__ == "__main__":
    print("\n" + "=" * 80)
    print("Testing Triton BMM Kernel")
    print("=" * 80)

    # Test triton_bmm
    print("\n--- Single batch, non-causal ---")
    test_triton_bmm(1, 32, [128], 8, 8, 128, False)

    print("\n--- Single batch, causal ---")
    test_triton_bmm(1, 32, [128], 8, 8, 128, True)

    print("\n--- Multiple batches, different k_len ---")
    test_triton_bmm(3, 32, [64, 128, 256], 8, 8, 128, False)

    print("\n" + "=" * 80)
    print("Testing Triton Rocket Paged KT Cache BMM Kernel")
    print("=" * 80)

    # Test triton_rocket_paged_kt_cache_bmm
    print("\n--- Single batch ---")
    test_triton_rocket_paged_kt_cache_bmm(1, [128], 8, 4, 128, 3, 64)

    print("\n--- Multiple batches, different kv_len ---")
    test_triton_rocket_paged_kt_cache_bmm(3, [100, 200, 150], 8, 4, 128, 3, 64)

    print("\n" + "=" * 80)
    print("All tests passed!")
    print("=" * 80)
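For orientation, a minimal sketch of the variable-length packing convention the tests above rely on (only torch is assumed; the names are local to this sketch):

    import torch

    k_lens = torch.tensor([64, 128, 256], dtype=torch.int32)
    k_cu_seqlens = torch.cat(
        [torch.zeros(1, dtype=torch.int32), torch.cumsum(k_lens, dim=0)]
    )
    # k_cu_seqlens holds [0, 64, 192, 448]; sequence i occupies
    # k[:, k_cu_seqlens[i]:k_cu_seqlens[i + 1], :] in a packed K tensor of shape
    # [num_kv_heads, total_k_tokens, head_dim].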
228 tests/unittest/_torch/attention/sparse/test_triton_topk.py Normal file
@ -0,0 +1,228 @@
import pytest
import torch

from tensorrt_llm._torch.attention_backend.sparse.kernel import triton_topk


def pytorch_reference_topk(
    input_tensor: torch.Tensor,
    kv_offsets: torch.Tensor,
    kv_lens: torch.Tensor,
    topk: int,
) -> torch.Tensor:
    """
    Args:
        input_tensor: Input scores [num_kv_heads, sum(kv_lens)]
        kv_offsets: KV offsets [batch_size + 1]
        kv_lens: KV lengths [batch_size]
        topk: TopK parameter

    Returns:
        output_indices: Padded indices [num_kv_heads, batch_size, topk]
    """
    num_kv_heads = input_tensor.shape[0]
    batch_size = len(kv_lens)
    device = input_tensor.device

    # Compute max sequence length for padding
    max_seq_len = kv_lens.max().item()

    # Ensure padding size >= topk
    pad_size = max(max_seq_len, topk)

    # Create padded tensor [num_kv_heads, batch_size, pad_size]
    padded_tensor = torch.full(
        (num_kv_heads, batch_size, pad_size), float("-inf"), dtype=input_tensor.dtype, device=device
    )

    # Fill in actual values
    for batch_idx in range(batch_size):
        start = kv_offsets[batch_idx].item()
        end = kv_offsets[batch_idx + 1].item()
        seq_len = kv_lens[batch_idx].item()

        for head_idx in range(num_kv_heads):
            padded_tensor[head_idx, batch_idx, :seq_len] = input_tensor[head_idx, start:end]

    # Perform batch topk: [num_kv_heads, batch_size, pad_size] -> [num_kv_heads, batch_size, topk]
    topk_values, topk_indices = torch.topk(
        padded_tensor,
        k=topk,
        dim=-1,
        largest=True,
    )

    # Mask out invalid indices based on each batch's seq_len
    seq_lens_expanded = kv_lens.to(device).unsqueeze(0).unsqueeze(-1)  # [1, batch_size, 1]
    # topk_indices: [num_kv_heads, batch_size, topk]
    mask = topk_indices >= seq_lens_expanded
    topk_indices.masked_fill_(mask, -1)

    return topk_indices

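# Worked example for pytorch_reference_topk above: with kv_lens = [3, 5] and
# topk = 4, pad_size = max(5, 4) = 5; batch 0 is padded with two -inf slots,
# torch.topk still returns 4 indices per (head, batch), and any returned index
# >= that batch's kv_len (i.e. a padded slot) is rewritten to -1.
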
def triton_topk_wrapper(
    input_tensor: torch.Tensor,
    kv_offsets: torch.Tensor,
    kv_lens: torch.Tensor,
    topk: int,
) -> torch.Tensor:
    """
    Args:
        input_tensor: Input scores [num_kv_heads, sum(kv_lens)]
        kv_offsets: KV offsets [batch_size + 1]
        kv_lens: KV lengths [batch_size]
        topk: TopK parameter

    Returns:
        output_indices: Padded indices [num_kv_heads, batch_size, topk]
    """
    num_kv_heads = input_tensor.shape[0]
    batch_size = len(kv_lens)
    device = input_tensor.device

    sparse_lens = torch.tensor(
        [min(topk, seq_len.item()) for seq_len in kv_lens], dtype=torch.int32, device=device
    )
    sparse_offsets = torch.cat(
        [torch.zeros(1, dtype=torch.int32, device=device), torch.cumsum(sparse_lens, dim=0)]
    ).to(device)
    total_sparse_indices = sparse_offsets[-1].item()

    output_indices_flat = triton_topk(
        input_tensor, kv_offsets, sparse_offsets, total_sparse_indices, topk
    )

    # Convert flat format to padded format [num_kv_heads, batch_size, topk]
    output_indices_padded = torch.full(
        (num_kv_heads, batch_size, topk), -1, dtype=torch.int32, device=device
    )

    for batch_idx in range(batch_size):
        start = sparse_offsets[batch_idx].item()
        end = sparse_offsets[batch_idx + 1].item()
        actual_len = end - start

        for head_idx in range(num_kv_heads):
            output_indices_padded[head_idx, batch_idx, :actual_len] = output_indices_flat[
                head_idx, start:end
            ]

    return output_indices_padded

def compute_overlap_ratio(
    triton_indices: torch.Tensor,
    reference_indices: torch.Tensor,
    kv_lens: torch.Tensor,
) -> float:
    """
    Args:
        triton_indices: Triton topk results [num_kv_heads, batch_size, topk]
        reference_indices: Reference topk results [num_kv_heads, batch_size, topk]
        kv_lens: KV lengths [batch_size]

    Returns:
        Average overlap ratio across all batches and heads
    """
    num_kv_heads = triton_indices.shape[0]
    batch_size = triton_indices.shape[1]

    overlap_ratios = []

    # Compare batch by batch
    for batch_idx in range(batch_size):
        for head_idx in range(num_kv_heads):
            # Extract indices for this batch and head
            triton_batch = triton_indices[head_idx, batch_idx, :].cpu().tolist()
            reference_batch = reference_indices[head_idx, batch_idx, :].cpu().tolist()

            # Filter out -1 (invalid/padding indices)
            triton_set = set([x for x in triton_batch if x >= 0])
            reference_set = set([x for x in reference_batch if x >= 0])

            if len(reference_set) > 0:
                overlap = len(triton_set & reference_set)
                overlap_ratio = overlap / len(reference_set)
                overlap_ratios.append(overlap_ratio)

    if len(overlap_ratios) == 0:
        return 1.0

    return sum(overlap_ratios) / len(overlap_ratios)

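# Worked example for compute_overlap_ratio above: if, for one (head, batch)
# pair, the Triton indices are {0, 5, 9, 17} and the reference indices are
# {0, 5, 9, 21}, the overlap is 3 / 4 = 0.75; ratios are averaged over all
# (head, batch) pairs, and pairs whose reference set is empty are skipped.
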
@pytest.mark.parametrize(
    "batch_size,seq_lens,num_kv_heads,topk",
    [
        # Single batch, seq_len > topk
        (1, [3000], 1, 2048),
        # Single batch, seq_len < topk
        (1, [1000], 8, 2048),
        # Multiple batches, mixed seq_len (some < topk, some > topk)
        (6, [50, 150, 80, 300, 100, 256], 8, 128),
        (6, [1000, 2500, 1500, 3000, 1800, 4000], 8, 2048),
    ],
)
def test_topk_kernel(batch_size, seq_lens, num_kv_heads, topk):
    device = torch.device("cuda")
    dtype = torch.float32

    kv_lens = torch.tensor(seq_lens, dtype=torch.int32, device=device)

    kv_offsets = torch.cat(
        [torch.zeros(1, dtype=torch.int32, device=device), torch.cumsum(kv_lens, dim=0)]
    ).to(device)

    total_tokens = kv_offsets[-1].item()

    input_tensor = torch.randn((num_kv_heads, total_tokens), dtype=dtype, device=device)

    triton_output = triton_topk_wrapper(
        input_tensor=input_tensor,
        kv_offsets=kv_offsets,
        kv_lens=kv_lens,
        topk=topk,
    )

    reference_output = pytorch_reference_topk(
        input_tensor=input_tensor,
        kv_offsets=kv_offsets,
        kv_lens=kv_lens,
        topk=topk,
    )

    overlap_ratio = compute_overlap_ratio(
        triton_output,
        reference_output,
        kv_lens,
    )

    min_threshold = 0.99

    print(f"overlap_ratio: {overlap_ratio}")

    assert overlap_ratio >= min_threshold, (
        f"Overlap ratio {overlap_ratio:.4f} is too low (< {min_threshold})"
    )


if __name__ == "__main__":
    print("\n" + "=" * 80)
    print("Testing Triton TopK Kernel")
    print("=" * 80)

    # Single batch tests
    print("\n--- Single batch, seq_len > topk ---")
    test_topk_kernel(1, [3000], 8, 2048)

    print("\n--- Single batch, seq_len < topk ---")
    test_topk_kernel(1, [1000], 8, 2048)

    print("\n--- Multiple batches, mixed seq_len ---")
    test_topk_kernel(6, [50, 150, 80, 300, 100, 256], 8, 128)
    test_topk_kernel(6, [1000, 2500, 1500, 3000, 1800, 4000], 8, 2048)

    print("\n" + "=" * 80)
    print("All tests passed!")
    print("=" * 80)