// Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14)
/*
 * Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include "tensorrt_llm/common/config.h"

#include <cstdint>
#include <cuda_runtime.h>
#include <sstream>
#include <string>
#include <tuple>

TRTLLM_NAMESPACE_BEGIN

namespace kernels
{
// Host-side parameter bundle for sparse-attention kernels. All pointers are
// raw device addresses owned by the caller; this struct performs no
// allocation or cleanup.
struct SparseAttentionParams
{
    // Per-head selected KV token indices: [num_kv_heads, num_sparse_kv_indices]
    int32_t* sparse_kv_indices{nullptr};
    // Per-head selected attention indices: [num_kv_heads, num_sparse_attn_indices]
    int32_t* sparse_attn_indices{nullptr};
    // Prefix-sum offsets into sparse_kv_indices: [num_contexts + 1]
    int32_t* sparse_kv_offsets{nullptr};
    // Prefix-sum offsets into sparse_attn_indices: [num_generations + 1]
    int32_t* sparse_attn_offsets{nullptr};
    int32_t sparse_mla_topk{0};              // top-k selection size for DSA attention
    void* sparse_mla_kv_cache_pool{nullptr}; // KV-cache pool base for DSA attention

    int32_t sparse_attn_indices_block_size{1};
    int32_t sparse_attn_indices_stride{0};

    // Renders every field (pointers as addresses, scalars as values), one
    // "name: value" pair per line, for debug logging.
    std::string toString() const
    {
        std::ostringstream out;
        out << "sparse_kv_indices: " << sparse_kv_indices << std::endl
            << "sparse_attn_indices: " << sparse_attn_indices << std::endl
            << "sparse_kv_offsets: " << sparse_kv_offsets << std::endl
            << "sparse_attn_offsets: " << sparse_attn_offsets << std::endl
            << "sparse_mla_topk: " << sparse_mla_topk << std::endl
            << "sparse_mla_kv_cache_pool: " << sparse_mla_kv_cache_pool << std::endl
            << "sparse_attn_indices_block_size: " << sparse_attn_indices_block_size << std::endl
            << "sparse_attn_indices_stride: " << sparse_attn_indices_stride << std::endl;
        return out.str();
    }
};
// One (max, sum) partial for a fused max/sum reduction; combined pairwise by
// PairReduceOp below.
struct Pair
{
    int32_t max_val; // running maximum over the reduced values
    int32_t sum_val; // running sum over the reduced values
};
// Binary reduction functor over Pair partials: keeps the larger max_val and
// adds the sum_vals. The operation is associative and commutative, so it is
// safe for tree/warp reductions that combine partials in any order.
struct PairReduceOp
{
#if defined(__CUDACC__)
    // __host__ __device__ (rather than __device__ only) so the same functor
    // is usable in host-side reference/validation paths of a .cu TU as well
    // as inside kernels; plain host compilers skip the qualifiers entirely.
    inline __host__ __device__
#endif
        Pair
        operator()(Pair const& a, Pair const& b) const
    {
        Pair result;
        result.max_val = a.max_val > b.max_val ? a.max_val : b.max_val;
        result.sum_val = a.sum_val + b.sum_val;
        return result;
    }
};
// Launches a kernel that gathers KV-cache page offsets and sequence lengths
// into per-KV-head copies, expanding the [batch_size, ...] inputs to
// [num_head_kv, batch_size, ...] outputs. The gather is presumably driven by
// the sparse index/offset tables in sparse_params — confirm against the
// kernel definition. Asynchronous: enqueued on `stream`; the caller owns all
// device buffers and any required synchronization.
void invokeGatherKvPageOffsets(int32_t* output_kv_page_offsets, // [num_head_kv, batch_size, 2, max_num_pages_per_seq]
    int32_t* output_seq_lengths,                                // [num_head_kv, batch_size]
    int32_t const* kv_page_offsets,                             // [batch_size, 2, max_num_pages_per_seq]
    int32_t const* seq_lengths,                                 // [batch_size]
    SparseAttentionParams const sparse_params, int32_t const batch_size, int32_t const num_head_kv,
    int32_t const tokens_per_page, int32_t const max_num_pages_per_seq, cudaStream_t stream);
} // namespace kernels

TRTLLM_NAMESPACE_END