[None][fix] Remove unused params in attn (#10652)

Signed-off-by: yizhang-nv <187001205+yizhang-nv@users.noreply.github.com>
Yi Zhang 2026-01-20 16:08:59 +08:00 committed by GitHub
parent 47e0ec2527
commit 58311b2345
GPG Key ID: B5690EEEBB952194
15 changed files with 110 additions and 142 deletions

View File

@@ -127,7 +127,6 @@ public:
public:
// Attention packed mask input (used by context FMHA).
uint32_t const* attention_packed_mask = nullptr;
kernels::KVBlockArray::DataType* host_block_offsets = nullptr;
int32_t batch_size = 0;
float2 const* mrope_rotary_cos_sin = nullptr;
@@ -182,7 +181,6 @@ public:
ss << "context_buf_sf: " << this->context_buf_sf << std::endl;
ss << "key_value_cache: " << (half*) this->key_value_cache << std::endl;
ss << "block_offsets: " << this->block_offsets << std::endl;
ss << "host_block_offsets: " << this->host_block_offsets << std::endl;
ss << "host_primary_pool_pointer: " << this->host_primary_pool_pointer << std::endl;
ss << "host_secondary_pool_pointer: " << this->host_secondary_pool_pointer << std::endl;
ss << "batch_size: " << this->batch_size << std::endl;

View File

@@ -42,19 +42,19 @@ void initBindings(nb::module_& m)
nb::arg("output_sf") = std::nullopt, nb::arg("workspace_") = std::nullopt, nb::arg("sequence_length"),
nb::arg("host_past_key_value_lengths"), nb::arg("host_total_kv_lens"), nb::arg("context_lengths"),
nb::arg("host_context_lengths"), nb::arg("host_request_types"),
nb::arg("kv_cache_block_offsets") = std::nullopt, nb::arg("host_kv_cache_block_offsets") = std::nullopt,
nb::arg("host_kv_cache_pool_pointers") = std::nullopt, nb::arg("host_kv_cache_pool_mapping") = std::nullopt,
nb::arg("cache_indirection") = std::nullopt, nb::arg("kv_scale_orig_quant") = std::nullopt,
nb::arg("kv_scale_quant_orig") = std::nullopt, nb::arg("out_scale") = std::nullopt,
nb::arg("rotary_inv_freq") = std::nullopt, nb::arg("rotary_cos_sin") = std::nullopt,
nb::arg("latent_cache") = std::nullopt, nb::arg("q_pe") = std::nullopt,
nb::arg("block_ids_per_seq") = std::nullopt, nb::arg("attention_sinks") = std::nullopt, nb::arg("is_fused_qkv"),
nb::arg("update_kv_cache"), nb::arg("predicted_tokens_per_seq"), nb::arg("layer_idx"), nb::arg("num_heads"),
nb::arg("num_kv_heads"), nb::arg("head_size"), nb::arg("tokens_per_block") = std::nullopt,
nb::arg("max_num_requests"), nb::arg("max_context_length"), nb::arg("attention_window_size"),
nb::arg("sink_token_length"), nb::arg("beam_width"), nb::arg("mask_type"), nb::arg("quant_mode"),
nb::arg("q_scaling"), nb::arg("position_embedding_type"), nb::arg("rotary_embedding_dim"),
nb::arg("rotary_embedding_base"), nb::arg("rotary_embedding_scale_type"), nb::arg("rotary_embedding_scales"),
nb::arg("kv_cache_block_offsets") = std::nullopt, nb::arg("host_kv_cache_pool_pointers") = std::nullopt,
nb::arg("host_kv_cache_pool_mapping") = std::nullopt, nb::arg("cache_indirection") = std::nullopt,
nb::arg("kv_scale_orig_quant") = std::nullopt, nb::arg("kv_scale_quant_orig") = std::nullopt,
nb::arg("out_scale") = std::nullopt, nb::arg("rotary_inv_freq") = std::nullopt,
nb::arg("rotary_cos_sin") = std::nullopt, nb::arg("latent_cache") = std::nullopt,
nb::arg("q_pe") = std::nullopt, nb::arg("block_ids_per_seq") = std::nullopt,
nb::arg("attention_sinks") = std::nullopt, nb::arg("is_fused_qkv"), nb::arg("update_kv_cache"),
nb::arg("predicted_tokens_per_seq"), nb::arg("layer_idx"), nb::arg("num_heads"), nb::arg("num_kv_heads"),
nb::arg("head_size"), nb::arg("tokens_per_block") = std::nullopt, nb::arg("max_num_requests"),
nb::arg("max_context_length"), nb::arg("attention_window_size"), nb::arg("sink_token_length"),
nb::arg("beam_width"), nb::arg("mask_type"), nb::arg("quant_mode"), nb::arg("q_scaling"),
nb::arg("position_embedding_type"), nb::arg("rotary_embedding_dim"), nb::arg("rotary_embedding_base"),
nb::arg("rotary_embedding_scale_type"), nb::arg("rotary_embedding_scales"),
nb::arg("rotary_embedding_max_position_info"), nb::arg("use_paged_context_fmha"),
nb::arg("attention_input_type") = std::nullopt, nb::arg("is_mla_enable"),
nb::arg("chunked_prefill_buffer_batch_size") = std::nullopt, nb::arg("q_lora_rank") = std::nullopt,

View File

@@ -858,7 +858,6 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32
int max_blocks_per_sequence = 0;
kernels::KVBlockArray::DataType* block_offsets = nullptr;
kernels::KVBlockArray::DataType* host_block_offsets = nullptr;
void* host_primary_pool_pointer = nullptr;
void* host_secondary_pool_pointer = nullptr;
if (useKVCache() && mPagedKVCache)
@@ -882,10 +881,6 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32
= reinterpret_cast<kernels::KVBlockArray::DataType*>(inputs[getIdx(IdxEntry::KV_CACHE_BLOCK_OFFSETS)])
+ poolOffset + seqOffset;
host_block_offsets
= reinterpret_cast<kernels::KVBlockArray::DataType*>(inputs[getIdx(IdxEntry::HOST_KV_CACHE_BLOCK_OFFSETS)])
+ poolOffset + seqOffset;
auto const* const typed_host_pool_pointers
= static_cast<char* const*>(inputs[getIdx(IdxEntry::HOST_KV_CACHE_POOL_POINTERS)]);
@@ -1046,7 +1041,6 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32
common_enqueue_params.max_past_kv_length = max_context_kv_len;
EnqueueContextParams<T> enqueue_params{common_enqueue_params};
enqueue_params.attention_packed_mask = attention_packed_mask;
enqueue_params.host_block_offsets = host_block_offsets;
enqueue_params.batch_size = batch_size;
enqueue_params.mrope_rotary_cos_sin = mrope_rotary_cos_sin;
enqueue_params.total_kv_len = enqueue_params.num_tokens;

View File

@@ -55,8 +55,7 @@ namespace tensorrt_llm::plugins
// all elements must be identical.
// 8. past_key_value_pool [batch_size, 2, local_num_kv_heads, max_seq_len, head_size] or
// block_offsets [batch_size, 2, max_blocks_per_seq] if paged kv cache (optional)
// 8.1 host_block_offsets [batch_size, 2, max_blocks_per_seq] if paged kv cache (optional)
// 8.2 host_pool_pointers [2] if paged kv cache (optional)
// 8.1 host_pool_pointers [2] if paged kv cache (optional)
// 9. kv_cache_quantization_scale [1] (optional)
// 10. kv_cache_dequantization_scale [1] (optional)
// 11. attention_output_quantization_scale [1] (on device, optional)

View File

@@ -42,19 +42,19 @@ void initBindings(pybind11::module_& m)
py::arg("output_sf") = std::nullopt, py::arg("workspace_") = std::nullopt, py::arg("sequence_length"),
py::arg("host_past_key_value_lengths"), py::arg("host_total_kv_lens"), py::arg("context_lengths"),
py::arg("host_context_lengths"), py::arg("host_request_types"),
py::arg("kv_cache_block_offsets") = std::nullopt, py::arg("host_kv_cache_block_offsets") = std::nullopt,
py::arg("host_kv_cache_pool_pointers") = std::nullopt, py::arg("host_kv_cache_pool_mapping") = std::nullopt,
py::arg("cache_indirection") = std::nullopt, py::arg("kv_scale_orig_quant") = std::nullopt,
py::arg("kv_scale_quant_orig") = std::nullopt, py::arg("out_scale") = std::nullopt,
py::arg("rotary_inv_freq") = std::nullopt, py::arg("rotary_cos_sin") = std::nullopt,
py::arg("latent_cache") = std::nullopt, py::arg("q_pe") = std::nullopt,
py::arg("block_ids_per_seq") = std::nullopt, py::arg("attention_sinks") = std::nullopt, py::arg("is_fused_qkv"),
py::arg("update_kv_cache"), py::arg("predicted_tokens_per_seq"), py::arg("layer_idx"), py::arg("num_heads"),
py::arg("num_kv_heads"), py::arg("head_size"), py::arg("tokens_per_block") = std::nullopt,
py::arg("max_num_requests"), py::arg("max_context_length"), py::arg("attention_window_size"),
py::arg("sink_token_length"), py::arg("beam_width"), py::arg("mask_type"), py::arg("quant_mode"),
py::arg("q_scaling"), py::arg("position_embedding_type"), py::arg("rotary_embedding_dim"),
py::arg("rotary_embedding_base"), py::arg("rotary_embedding_scale_type"), py::arg("rotary_embedding_scales"),
py::arg("kv_cache_block_offsets") = std::nullopt, py::arg("host_kv_cache_pool_pointers") = std::nullopt,
py::arg("host_kv_cache_pool_mapping") = std::nullopt, py::arg("cache_indirection") = std::nullopt,
py::arg("kv_scale_orig_quant") = std::nullopt, py::arg("kv_scale_quant_orig") = std::nullopt,
py::arg("out_scale") = std::nullopt, py::arg("rotary_inv_freq") = std::nullopt,
py::arg("rotary_cos_sin") = std::nullopt, py::arg("latent_cache") = std::nullopt,
py::arg("q_pe") = std::nullopt, py::arg("block_ids_per_seq") = std::nullopt,
py::arg("attention_sinks") = std::nullopt, py::arg("is_fused_qkv"), py::arg("update_kv_cache"),
py::arg("predicted_tokens_per_seq"), py::arg("layer_idx"), py::arg("num_heads"), py::arg("num_kv_heads"),
py::arg("head_size"), py::arg("tokens_per_block") = std::nullopt, py::arg("max_num_requests"),
py::arg("max_context_length"), py::arg("attention_window_size"), py::arg("sink_token_length"),
py::arg("beam_width"), py::arg("mask_type"), py::arg("quant_mode"), py::arg("q_scaling"),
py::arg("position_embedding_type"), py::arg("rotary_embedding_dim"), py::arg("rotary_embedding_base"),
py::arg("rotary_embedding_scale_type"), py::arg("rotary_embedding_scales"),
py::arg("rotary_embedding_max_position_info"), py::arg("use_paged_context_fmha"),
py::arg("attention_input_type") = std::nullopt, py::arg("is_mla_enable"),
py::arg("chunked_prefill_buffer_batch_size") = std::nullopt, py::arg("q_lora_rank") = std::nullopt,

View File

@@ -75,7 +75,6 @@ public:
torch::optional<torch::Tensor> k, torch::optional<torch::Tensor> v, torch::Tensor sequence_length,
torch::Tensor host_past_key_value_lengths, int32_t const total_kv_len, torch::Tensor context_lengths,
torch::Tensor host_context_lengths, torch::optional<torch::Tensor> kv_cache_block_offsets,
torch::optional<torch::Tensor> host_kv_cache_block_offsets,
torch::optional<torch::Tensor> host_kv_cache_pool_pointers,
torch::optional<torch::Tensor> host_kv_cache_pool_mapping, torch::optional<torch::Tensor> cache_indirection,
torch::optional<torch::Tensor> kv_scale_orig_quant, torch::optional<torch::Tensor> kv_scale_quant_orig,
@@ -136,7 +135,6 @@ public:
torch::optional<torch::Tensor> k, torch::optional<torch::Tensor> v, torch::Tensor sequence_length,
torch::Tensor host_past_key_value_lengths, int32_t const total_kv_len, torch::Tensor context_lengths,
torch::Tensor host_context_lengths, torch::optional<torch::Tensor> kv_cache_block_offsets,
torch::optional<torch::Tensor> host_kv_cache_block_offsets,
torch::optional<torch::Tensor> host_kv_cache_pool_pointers,
torch::optional<torch::Tensor> host_kv_cache_pool_mapping, torch::optional<torch::Tensor> cache_indirection,
torch::optional<torch::Tensor> kv_scale_orig_quant, torch::optional<torch::Tensor> kv_scale_quant_orig,
@@ -289,7 +287,6 @@ public:
// Commonly, cyclic_attention_window_size, and max_attention_window_size will be the same
// unless each layer has different attention window sizes.
// the kv_cache capacity.
int const max_attention_window_size = beam_width == 1 ? attention_window_size
: cache_indirection.has_value() ? cache_indirection.value().size(2)
: attention_window_size;
@@ -310,10 +307,6 @@ public:
= static_cast<KVBlockArray::DataType*>(op.useKVCache() && kv_cache_block_offsets.has_value()
? kv_cache_block_offsets.value().index({pool_index, seq_offset}).data_ptr()
: nullptr);
KVBlockArray::DataType* host_block_offsets
= static_cast<KVBlockArray::DataType*>(op.useKVCache() && host_kv_cache_block_offsets.has_value()
? host_kv_cache_block_offsets.value().index({pool_index, seq_offset}).data_ptr()
: nullptr);
// The cache element size in bits.
int cache_elem_bits = op.getKvCacheElemSizeInBits<T>();
@@ -463,7 +456,6 @@ public:
{
common_enqueue_params.input_seq_length = max_context_q_len;
AttentionOp::EnqueueContextParams<T> enqueue_params{common_enqueue_params};
enqueue_params.host_block_offsets = host_block_offsets;
enqueue_params.batch_size = num_seqs;
enqueue_params.k_ptr = k_ptr;
enqueue_params.v_ptr = v_ptr;
@@ -606,19 +598,19 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
std::optional<torch::Tensor> output_sf, std::optional<torch::Tensor> workspace_, torch::Tensor sequence_length,
torch::Tensor host_past_key_value_lengths, torch::Tensor host_total_kv_lens, torch::Tensor context_lengths,
torch::Tensor host_context_lengths, torch::Tensor host_request_types,
std::optional<torch::Tensor> kv_cache_block_offsets, std::optional<torch::Tensor> host_kv_cache_block_offsets,
std::optional<torch::Tensor> host_kv_cache_pool_pointers, std::optional<torch::Tensor> host_kv_cache_pool_mapping,
std::optional<torch::Tensor> cache_indirection, std::optional<torch::Tensor> kv_scale_orig_quant,
std::optional<torch::Tensor> kv_scale_quant_orig, std::optional<torch::Tensor> out_scale,
std::optional<torch::Tensor> rotary_inv_freq, std::optional<torch::Tensor> rotary_cos_sin,
std::optional<torch::Tensor> latent_cache, std::optional<torch::Tensor> q_pe,
std::optional<torch::Tensor> block_ids_per_seq, std::optional<torch::Tensor> attention_sinks,
bool const is_fused_qkv, bool const update_kv_cache, int64_t const predicted_tokens_per_seq,
int64_t const layer_idx, int64_t const num_heads, int64_t const num_kv_heads, int64_t const head_size,
std::optional<int64_t> const tokens_per_block, int64_t const max_num_requests, int64_t const max_context_length,
int64_t const attention_window_size, int64_t const sink_token_length, int64_t const beam_width,
int64_t const mask_type, int64_t const quant_mode, double const q_scaling, int64_t const position_embedding_type,
int64_t const rotary_embedding_dim, double const rotary_embedding_base, int64_t const rotary_embedding_scale_type,
std::optional<torch::Tensor> kv_cache_block_offsets, std::optional<torch::Tensor> host_kv_cache_pool_pointers,
std::optional<torch::Tensor> host_kv_cache_pool_mapping, std::optional<torch::Tensor> cache_indirection,
std::optional<torch::Tensor> kv_scale_orig_quant, std::optional<torch::Tensor> kv_scale_quant_orig,
std::optional<torch::Tensor> out_scale, std::optional<torch::Tensor> rotary_inv_freq,
std::optional<torch::Tensor> rotary_cos_sin, std::optional<torch::Tensor> latent_cache,
std::optional<torch::Tensor> q_pe, std::optional<torch::Tensor> block_ids_per_seq,
std::optional<torch::Tensor> attention_sinks, bool const is_fused_qkv, bool const update_kv_cache,
int64_t const predicted_tokens_per_seq, int64_t const layer_idx, int64_t const num_heads,
int64_t const num_kv_heads, int64_t const head_size, std::optional<int64_t> const tokens_per_block,
int64_t const max_num_requests, int64_t const max_context_length, int64_t const attention_window_size,
int64_t const sink_token_length, int64_t const beam_width, int64_t const mask_type, int64_t const quant_mode,
double const q_scaling, int64_t const position_embedding_type, int64_t const rotary_embedding_dim,
double const rotary_embedding_base, int64_t const rotary_embedding_scale_type,
std::vector<double> rotary_embedding_scales, std::vector<int64_t> rotary_embedding_max_position_info,
bool const use_paged_context_fmha, std::optional<int64_t> attention_input_type, bool is_mla_enable,
std::optional<int64_t> chunked_prefill_buffer_batch_size, std::optional<int64_t> q_lora_rank,
@@ -639,8 +631,8 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
{
TLLM_LOG_TRACE("Attention op starts at layer %d", layer_idx);
// Use these tensors to infer if the attention is using KV cache
bool const use_kv_cache = kv_cache_block_offsets.has_value() && host_kv_cache_block_offsets.has_value()
&& host_kv_cache_pool_pointers.has_value() && host_kv_cache_pool_mapping.has_value();
bool const use_kv_cache = kv_cache_block_offsets.has_value() && host_kv_cache_pool_pointers.has_value()
&& host_kv_cache_pool_mapping.has_value();
TLLM_CHECK_WITH_INFO(is_mla_enable || is_fused_qkv, "Only fused QKV is supported for non-MLA attention now");
TLLM_CHECK_WITH_INFO(update_kv_cache, "KV cache update cannot be disabled now");
@@ -894,13 +886,12 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
/*num_seqs=*/num_contexts, token_offset,
/*num_tokens=*/num_ctx_tokens, predicted_tokens_per_seq, workspace, output, output_sf, qkv_or_q, k, v,
sequence_length, host_past_key_value_lengths, ctx_total_kv_len, context_lengths, host_context_lengths,
kv_cache_block_offsets, host_kv_cache_block_offsets, host_kv_cache_pool_pointers,
host_kv_cache_pool_mapping, cache_indirection, kv_scale_orig_quant, kv_scale_quant_orig, out_scale,
rotary_inv_freq, rotary_cos_sin, latent_cache, q_pe, block_ids_per_seq, mrope_rotary_cos_sin,
mrope_position_deltas, mla_tensor_params, softmax_stats_tensor, spec_decoding_tensor_params,
attention_sinks, sparse_kv_indices, sparse_kv_offsets, sparse_attn_indices, sparse_attn_offsets,
sparse_attn_indices_block_size, sparse_mla_topk_value, cu_q_seqlens, cu_kv_seqlens, fmha_scheduler_counter,
mla_bmm1_scale, mla_bmm2_scale, quant_q_buffer);
kv_cache_block_offsets, host_kv_cache_pool_pointers, host_kv_cache_pool_mapping, cache_indirection,
kv_scale_orig_quant, kv_scale_quant_orig, out_scale, rotary_inv_freq, rotary_cos_sin, latent_cache, q_pe,
block_ids_per_seq, mrope_rotary_cos_sin, mrope_position_deltas, mla_tensor_params, softmax_stats_tensor,
spec_decoding_tensor_params, attention_sinks, sparse_kv_indices, sparse_kv_offsets, sparse_attn_indices,
sparse_attn_offsets, sparse_attn_indices_block_size, sparse_mla_topk_value, cu_q_seqlens, cu_kv_seqlens,
fmha_scheduler_counter, mla_bmm1_scale, mla_bmm2_scale, quant_q_buffer);
}
if ((num_generations > 0) && (attn_input_type != AttentionInputType::ContextOnly))
@@ -913,13 +904,12 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
/*num_seqs=*/num_generations, token_offset,
/*num_tokens=*/num_gen_tokens, predicted_tokens_per_seq, workspace, output, output_sf, qkv_or_q, k, v,
sequence_length, host_past_key_value_lengths, gen_total_kv_len, context_lengths, host_context_lengths,
kv_cache_block_offsets, host_kv_cache_block_offsets, host_kv_cache_pool_pointers,
host_kv_cache_pool_mapping, cache_indirection, kv_scale_orig_quant, kv_scale_quant_orig, out_scale,
rotary_inv_freq, rotary_cos_sin, latent_cache, q_pe, block_ids_per_seq, mrope_rotary_cos_sin,
mrope_position_deltas, mla_tensor_params, softmax_stats_tensor, spec_decoding_tensor_params,
attention_sinks, sparse_kv_indices, sparse_kv_offsets, sparse_attn_indices, sparse_attn_offsets,
sparse_attn_indices_block_size, sparse_mla_topk_value, cu_q_seqlens, cu_kv_seqlens, fmha_scheduler_counter,
mla_bmm1_scale, mla_bmm2_scale, quant_q_buffer);
kv_cache_block_offsets, host_kv_cache_pool_pointers, host_kv_cache_pool_mapping, cache_indirection,
kv_scale_orig_quant, kv_scale_quant_orig, out_scale, rotary_inv_freq, rotary_cos_sin, latent_cache, q_pe,
block_ids_per_seq, mrope_rotary_cos_sin, mrope_position_deltas, mla_tensor_params, softmax_stats_tensor,
spec_decoding_tensor_params, attention_sinks, sparse_kv_indices, sparse_kv_offsets, sparse_attn_indices,
sparse_attn_offsets, sparse_attn_indices_block_size, sparse_mla_topk_value, cu_q_seqlens, cu_kv_seqlens,
fmha_scheduler_counter, mla_bmm1_scale, mla_bmm2_scale, quant_q_buffer);
}
TLLM_LOG_TRACE("Attention op stops at layer %d", layer_idx);

View File

@@ -42,19 +42,19 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
std::optional<torch::Tensor> output_sf, std::optional<torch::Tensor> workspace_, torch::Tensor sequence_length,
torch::Tensor host_past_key_value_lengths, torch::Tensor host_total_kv_lens, torch::Tensor context_lengths,
torch::Tensor host_context_lengths, torch::Tensor host_request_types,
std::optional<torch::Tensor> kv_cache_block_offsets, std::optional<torch::Tensor> host_kv_cache_block_offsets,
std::optional<torch::Tensor> host_kv_cache_pool_pointers, std::optional<torch::Tensor> host_kv_cache_pool_mapping,
std::optional<torch::Tensor> cache_indirection, std::optional<torch::Tensor> kv_scale_orig_quant,
std::optional<torch::Tensor> kv_scale_quant_orig, std::optional<torch::Tensor> out_scale,
std::optional<torch::Tensor> rotary_inv_freq, std::optional<torch::Tensor> rotary_cos_sin,
std::optional<torch::Tensor> latent_cache, std::optional<torch::Tensor> q_pe,
std::optional<torch::Tensor> block_ids_per_seq, std::optional<torch::Tensor> attention_sinks,
bool const is_fused_qkv, bool const update_kv_cache, int64_t const predicted_tokens_per_seq,
int64_t const layer_idx, int64_t const num_heads, int64_t const num_kv_heads, int64_t const head_size,
std::optional<int64_t> const tokens_per_block, int64_t const max_num_requests, int64_t const max_context_length,
int64_t const attention_window_size, int64_t const sink_token_length, int64_t const beam_width,
int64_t const mask_type, int64_t const quant_mode, double const q_scaling, int64_t const position_embedding_type,
int64_t const rotary_embedding_dim, double const rotary_embedding_base, int64_t const rotary_embedding_scale_type,
std::optional<torch::Tensor> kv_cache_block_offsets, std::optional<torch::Tensor> host_kv_cache_pool_pointers,
std::optional<torch::Tensor> host_kv_cache_pool_mapping, std::optional<torch::Tensor> cache_indirection,
std::optional<torch::Tensor> kv_scale_orig_quant, std::optional<torch::Tensor> kv_scale_quant_orig,
std::optional<torch::Tensor> out_scale, std::optional<torch::Tensor> rotary_inv_freq,
std::optional<torch::Tensor> rotary_cos_sin, std::optional<torch::Tensor> latent_cache,
std::optional<torch::Tensor> q_pe, std::optional<torch::Tensor> block_ids_per_seq,
std::optional<torch::Tensor> attention_sinks, bool const is_fused_qkv, bool const update_kv_cache,
int64_t const predicted_tokens_per_seq, int64_t const layer_idx, int64_t const num_heads,
int64_t const num_kv_heads, int64_t const head_size, std::optional<int64_t> const tokens_per_block,
int64_t const max_num_requests, int64_t const max_context_length, int64_t const attention_window_size,
int64_t const sink_token_length, int64_t const beam_width, int64_t const mask_type, int64_t const quant_mode,
double const q_scaling, int64_t const position_embedding_type, int64_t const rotary_embedding_dim,
double const rotary_embedding_base, int64_t const rotary_embedding_scale_type,
std::vector<double> rotary_embedding_scales, std::vector<int64_t> rotary_embedding_max_position_info,
bool const use_paged_context_fmha, std::optional<int64_t> attention_input_type, bool is_mla_enable,
std::optional<int64_t> chunked_prefill_buffer_batch_size, std::optional<int64_t> q_lora_rank,

View File

@@ -121,8 +121,7 @@ void MLARopeGeneration(torch::Tensor fused_q, // [tokens, num_heads, (nope_dim +
std::optional<torch::Tensor> mla_bmm2_scale, std::optional<torch::Tensor> quant_q_buffer,
torch::Tensor sequence_length, torch::Tensor host_past_key_value_lengths, torch::Tensor host_context_lengths,
int64_t const num_contexts, std::optional<torch::Tensor> kv_cache_block_offsets,
std::optional<torch::Tensor> host_kv_cache_block_offsets, std::optional<torch::Tensor> host_kv_cache_pool_pointers,
std::optional<torch::Tensor> host_kv_cache_pool_mapping,
std::optional<torch::Tensor> host_kv_cache_pool_pointers, std::optional<torch::Tensor> host_kv_cache_pool_mapping,
torch::optional<torch::Tensor> kv_scale_orig_quant, // [1] q,k quant scale
torch::optional<torch::Tensor> kv_scale_quant_orig, // [1] bmm quant scale
torch::optional<torch::Tensor> out_scale, // [1] output quant scale
@@ -147,8 +146,8 @@ void MLARopeGeneration(torch::Tensor fused_q, // [tokens, num_heads, (nope_dim +
TLLM_CHECK_WITH_INFO(
host_kv_cache_pool_mapping.has_value(), "KV cache pool mapping is required for MLA generation.");
bool const use_kv_cache = kv_cache_block_offsets.has_value() && host_kv_cache_block_offsets.has_value()
&& host_kv_cache_pool_pointers.has_value() && host_kv_cache_pool_mapping.has_value();
bool const use_kv_cache = kv_cache_block_offsets.has_value() && host_kv_cache_pool_pointers.has_value()
&& host_kv_cache_pool_mapping.has_value();
int32_t const num_seqs = host_context_lengths.size(0);
@@ -331,7 +330,6 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m)
", Tensor host_context_lengths"
", int num_contexts"
", Tensor? kv_cache_block_offsets"
", Tensor? host_kv_cache_block_offsets"
", Tensor? host_kv_cache_pool_pointers"
", Tensor? host_kv_cache_pool_mapping"
", Tensor? kv_scale_orig_quant"

View File

@@ -159,12 +159,11 @@ KVBlockArray createKVBlockArray(int num_contexts, int max_blocks_per_sequence, i
std::vector<torch::Tensor> loadPagedKVCacheForMLA(torch::ScalarType out_dtype, int64_t const num_contexts,
int64_t const num_ctx_cached_tokens, int64_t const max_ctx_cached_kv_len, torch::Tensor& cu_ctx_cached_kv_lens,
torch::Tensor const& kv_cache_block_offsets, torch::Tensor const& host_kv_cache_block_offsets,
torch::Tensor const& host_kv_cache_pool_pointers, torch::Tensor const& host_kv_cache_pool_mapping,
torch::optional<torch::Tensor> kv_scale_orig_quant, torch::optional<torch::Tensor> kv_scale_quant_orig,
int64_t const layer_idx, int64_t const lora_size, int64_t const rope_size, int64_t const tokens_per_block,
int64_t const attention_window_size, int64_t const sink_token_length, int64_t const beam_width,
int64_t const quant_mode)
torch::Tensor const& kv_cache_block_offsets, torch::Tensor const& host_kv_cache_pool_pointers,
torch::Tensor const& host_kv_cache_pool_mapping, torch::optional<torch::Tensor> kv_scale_orig_quant,
torch::optional<torch::Tensor> kv_scale_quant_orig, int64_t const layer_idx, int64_t const lora_size,
int64_t const rope_size, int64_t const tokens_per_block, int64_t const attention_window_size,
int64_t const sink_token_length, int64_t const beam_width, int64_t const quant_mode)
{
TORCH_CHECK(out_dtype == torch::kFloat16 || out_dtype == torch::kFloat32 || out_dtype == torch::kBFloat16,
"out_dtype only support float16, float32, bfloat16");
@@ -348,11 +347,11 @@ void MLARopeAppendPagedKVAssignQ(torch::Tensor& q, torch::Tensor& latent_cache,
torch::Tensor const& cu_ctx_cached_kv_lens, torch::Tensor const& cu_seq_lens,
int64_t const max_input_uncached_seq_len, torch::Tensor const& cos_sin_cache, int64_t const head_num,
int64_t const nope_size, int64_t const rope_size, int64_t const lora_size,
torch::Tensor const& kv_cache_block_offsets, torch::Tensor const& host_kv_cache_block_offsets,
torch::Tensor const& host_kv_cache_pool_pointers, torch::Tensor const& host_kv_cache_pool_mapping,
torch::optional<torch::Tensor> kv_scale_orig_quant, torch::optional<torch::Tensor> kv_scale_quant_orig,
int64_t const layer_idx, int64_t const tokens_per_block, int64_t const attention_window_size,
int64_t const sink_token_length, int64_t const beam_width, int64_t const quant_mode)
torch::Tensor const& kv_cache_block_offsets, torch::Tensor const& host_kv_cache_pool_pointers,
torch::Tensor const& host_kv_cache_pool_mapping, torch::optional<torch::Tensor> kv_scale_orig_quant,
torch::optional<torch::Tensor> kv_scale_quant_orig, int64_t const layer_idx, int64_t const tokens_per_block,
int64_t const attention_window_size, int64_t const sink_token_length, int64_t const beam_width,
int64_t const quant_mode)
{
auto input_dtype = q.scalar_type();
TORCH_CHECK(input_dtype == torch::kFloat16 || input_dtype == torch::kFloat32 || input_dtype == torch::kBFloat16);
@@ -482,7 +481,6 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m)
", int max_ctx_cached_kv_len"
", Tensor cu_ctx_cached_kv_lens"
", Tensor kv_cache_block_offsets"
", Tensor host_kv_cache_block_offsets"
", Tensor host_kv_cache_pool_pointers"
", Tensor host_kv_cache_pool_mapping"
", Tensor? kv_scale_orig_quant"
@@ -550,7 +548,6 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m)
", int rope_size"
", int lora_size"
", Tensor kv_cache_block_offsets"
", Tensor host_kv_cache_block_offsets"
", Tensor host_kv_cache_pool_pointers"
", Tensor host_kv_cache_pool_mapping"
", Tensor? kv_scale_orig_quant"

View File

@@ -1679,16 +1679,12 @@ class DSATrtllmAttention(TrtllmAttention):
max_seq_len = metadata.max_gen_seq_len
block_offsets = metadata.kv_cache_block_offsets[:, metadata.
num_contexts:]
host_block_offsets = metadata.host_kv_cache_block_offsets[:,
metadata.
num_contexts:]
else:
cached_token_indptr = metadata.ctx_cached_token_indptr
kv_indptr = metadata.ctx_kv_indptr
num_seqs = metadata.num_contexts
max_seq_len = metadata.max_ctx_seq_len
block_offsets = metadata.kv_cache_block_offsets
host_block_offsets = metadata.host_kv_cache_block_offsets
assert self.is_mla_enable and self.mla_params is not None
assert metadata.kv_cache_manager is not None
@@ -1708,7 +1704,6 @@ class DSATrtllmAttention(TrtllmAttention):
self.mla_params.qk_rope_head_dim,
self.mla_params.kv_lora_rank,
block_offsets,
host_block_offsets,
metadata.kv_cache_manager.kv_cache_pool_pointers,
metadata.kv_cache_manager.kv_cache_pool_mapping,
self.kv_scale_orig_quant,
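
Note on the DSA path above: it now works only with the device-resident kv_cache_block_offsets tensor, and the generation branch simply takes a view of it past the context requests. A minimal sketch of that slicing with illustrative shapes, not tied to the real metadata object:

    import torch

    num_pools, num_contexts, num_gens, max_blocks_per_seq = 1, 3, 2, 16
    kv_cache_block_offsets = torch.zeros(num_pools, num_contexts + num_gens, 2,
                                         max_blocks_per_seq, dtype=torch.int32)

    # Generation requests are stored after the context requests, so the
    # generation-only view is a slice along the sequence axis; no pinned host
    # mirror of this tensor is kept by the attention backend anymore.
    gen_block_offsets = kv_cache_block_offsets[:, num_contexts:]
    assert gen_block_offsets.shape == (num_pools, num_gens, 2, max_blocks_per_seq)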

View File

@@ -35,7 +35,6 @@ class TrtllmAttentionWrapper:
host_context_lengths: torch.Tensor
host_request_types: torch.Tensor
kv_cache_block_offsets: torch.Tensor
host_kv_cache_block_offsets: torch.Tensor
host_kv_cache_pool_pointers: torch.Tensor
host_kv_cache_pool_mapping: torch.Tensor
workspace: Optional[torch.Tensor]
@@ -181,7 +180,6 @@ class TrtllmAttentionWrapper:
host_context_lengths: torch.Tensor = ...,
host_request_types: torch.Tensor = ...,
kv_cache_block_offsets: Optional[torch.Tensor] = None,
host_kv_cache_block_offsets: Optional[torch.Tensor] = None,
host_kv_cache_pool_pointers: Optional[torch.Tensor] = None,
host_kv_cache_pool_mapping: Optional[torch.Tensor] = None,
block_ids_per_seq: Optional[torch.Tensor] = None,
@@ -242,8 +240,7 @@ class TrtllmAttentionWrapper:
context_lengths (torch.Tensor): The context-phase sequence length of each request with shape (batch_size) on GPU.
host_context_lengths (torch.Tensor): Same as context_lengths, but on CPU.
host_request_types (torch.Tensor): The tensor that indicates whether a request is in context or generation phase, with shape (batch_size) on CPU.
kv_cache_block_offsets (torch.Tensor): The offsets to the blocks inside KV cache pools on GPU, its shape is (num_pools, max_batch_size * max_beam_width, 2, max_blocks_per_sequence), one for each block. If kv_cache_block_offsets, host_kv_cache_block_offsets, host_kv_cache_pool_pointers, host_kv_cache_pool_mapping are all None, the attention will be no cache attention.
host_kv_cache_block_offsets (torch.Tensor): Same as kv_cache_block_offsets, but on CPU.
kv_cache_block_offsets (torch.Tensor): The offsets to the blocks inside KV cache pools on GPU, its shape is (num_pools, max_batch_size * max_beam_width, 2, max_blocks_per_sequence), one for each block. If kv_cache_block_offsets, host_kv_cache_pool_pointers, host_kv_cache_pool_mapping are all None, the attention will be no cache attention.
host_kv_cache_pool_pointers (torch.Tensor): The pointers to the KV cache pools on CPU, its shape is (num_pools, 2), one for primary pool in GPU memory, one for secondary pool in CPU memory.
host_kv_cache_pool_mapping (torch.Tensor): The index of the pool used by each attention layer on CPU, its shape is (num_local_attention_layers). The local attention layers mean all attention layers in the current PP stage in the pipeline parallelism case.
workspace (torch.Tensor): An optional workspace tensor on GPU.
@@ -284,7 +281,6 @@ class TrtllmAttentionWrapper:
self.host_context_lengths = host_context_lengths
self.host_request_types = host_request_types
self.kv_cache_block_offsets = kv_cache_block_offsets
self.host_kv_cache_block_offsets = host_kv_cache_block_offsets
self.host_kv_cache_pool_pointers = host_kv_cache_pool_pointers
self.host_kv_cache_pool_mapping = host_kv_cache_pool_mapping
self.workspace = workspace
@@ -511,7 +507,6 @@ class TrtllmAttentionWrapper:
self.host_context_lengths,
self.host_request_types,
self.kv_cache_block_offsets,
self.host_kv_cache_block_offsets,
self.host_kv_cache_pool_pointers,
self.host_kv_cache_pool_mapping,
self.cache_indirection,
@@ -789,11 +784,6 @@ class TrtllmAttentionMetadata(AttentionMetadata):
dtype=torch.int32,
capture_graph=capture_graph,
)
self.host_kv_cache_block_offsets = torch.empty_like(
self.kv_cache_block_offsets,
device='cpu',
pin_memory=True,
)
self.block_ids_per_seq = None
self.kv_block_ids_per_seq = None
if self.enable_flash_mla:
@@ -933,7 +923,6 @@ class TrtllmAttentionMetadata(AttentionMetadata):
# set params that are used in wrapper.plan()
self.kv_cache_block_offsets = None
self.host_kv_cache_block_offsets = None
self.block_ids_per_seq = None
prompt_lens = torch.tensor(
@@ -986,19 +975,10 @@ class TrtllmAttentionMetadata(AttentionMetadata):
# kv block offsets
assert self.request_ids is not None
if self.kv_cache_manager is not None:
# Copy blocks for all context requests
self.kv_cache_manager.impl.copy_batch_block_offsets(
self.host_kv_cache_block_offsets,
self.request_ids[:self.num_contexts], 1, 0)
# Copy blocks for all generation requests
self.kv_cache_manager.impl.copy_batch_block_offsets(
self.host_kv_cache_block_offsets,
self.request_ids[self.num_contexts:], self.beam_width,
self.num_contexts)
for pool_idx in range(self.host_kv_cache_block_offsets.shape[0]):
self.kv_cache_block_offsets[pool_idx, :self.num_seqs].copy_(
self.host_kv_cache_block_offsets[pool_idx, :self.num_seqs],
non_blocking=True)
self.kv_cache_manager.copy_batch_block_offsets(
self.kv_cache_block_offsets, self.request_ids, self.beam_width,
self.num_contexts, self.num_seqs)
error_message = (
f"The max KV cache length of input sequences ({self.kv_lens[:self.num_seqs].max()}) "
f"exceeds the KV cache manager's maximum supported length "
@@ -1723,7 +1703,6 @@ class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]):
host_context_lengths=metadata.prompt_lens_cpu_runtime,
host_request_types=metadata.host_request_types_runtime,
kv_cache_block_offsets=metadata.kv_cache_block_offsets,
host_kv_cache_block_offsets=metadata.host_kv_cache_block_offsets,
host_kv_cache_pool_pointers=metadata.host_kv_cache_pool_pointers,
host_kv_cache_pool_mapping=metadata.host_kv_cache_pool_mapping,
block_ids_per_seq=metadata.block_ids_per_seq,
@@ -1847,7 +1826,6 @@ class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]):
metadata.max_ctx_kv_len,
metadata.ctx_kv_indptr,
metadata.kv_cache_block_offsets,
metadata.host_kv_cache_block_offsets,
metadata.kv_cache_manager.kv_cache_pool_pointers,
metadata.kv_cache_manager.kv_cache_pool_mapping,
self.kv_scale_orig_quant,
@@ -1938,7 +1916,6 @@ class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]):
self.mla_params.qk_rope_head_dim,
self.mla_params.kv_lora_rank,
metadata.kv_cache_block_offsets,
metadata.host_kv_cache_block_offsets,
metadata.kv_cache_manager.kv_cache_pool_pointers,
metadata.kv_cache_manager.kv_cache_pool_mapping,
self.kv_scale_orig_quant,
@@ -2052,7 +2029,6 @@ class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]):
metadata.prompt_lens_cpu_runtime, # host_context_lengths,
metadata.num_contexts,
metadata.kv_cache_block_offsets,
metadata.host_kv_cache_block_offsets,
metadata.kv_cache_manager.kv_cache_pool_pointers,
metadata.kv_cache_manager.kv_cache_pool_mapping,
self.kv_scale_orig_quant,
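
With host_kv_cache_block_offsets removed from the wrapper and the custom op, the "is this paged KV cache?" decision reduces to three optional tensors instead of four. A minimal sketch of that predicate, assuming plain Optional[torch.Tensor] arguments; the helper name is illustrative and not an API in the codebase:

    from typing import Optional

    import torch


    def uses_paged_kv_cache(
            kv_cache_block_offsets: Optional[torch.Tensor],
            host_kv_cache_pool_pointers: Optional[torch.Tensor],
            host_kv_cache_pool_mapping: Optional[torch.Tensor]) -> bool:
        # Mirrors the updated use_kv_cache condition: the host copy of the
        # block-offset tensor is no longer part of the check.
        return (kv_cache_block_offsets is not None
                and host_kv_cache_pool_pointers is not None
                and host_kv_cache_pool_mapping is not None)


    assert not uses_paged_kv_cache(None, None, None)  # no-cache attention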

View File

@@ -889,7 +889,6 @@ def _register_fake():
host_context_lengths: torch.Tensor,
num_contexts: int,
kv_cache_block_offsets: Optional[torch.Tensor],
host_kv_cache_block_offsets: Optional[torch.Tensor],
host_kv_cache_pool_pointers: Optional[torch.Tensor],
host_kv_cache_pool_mapping: Optional[torch.Tensor],
kv_scale_orig_quant: Optional[torch.Tensor],

View File

@@ -407,6 +407,14 @@ class KVCacheManager(BaseResourceManager):
self.num_pools = self.impl.num_pools
self.max_blocks_per_seq = self.impl.max_blocks_per_seq
self.enable_block_reuse = kv_cache_config.enable_block_reuse
self.host_kv_cache_block_offsets = torch.empty(self.num_pools,
max_batch_size *
max_beam_width,
2,
self.max_blocks_per_seq,
dtype=torch.int32,
pin_memory=True,
device='cpu')
def shutdown(self):
self.impl.release_pools()
@@ -650,7 +658,7 @@ class KVCacheManager(BaseResourceManager):
accepted_draft_token_offsets, packed_accepted_draft_tokens_indices, rewind_draft_token_separate_adjustments = self.locate_accepted_draft_tokens(
requests)
past_key_value_lengths = attn_metadata.kv_lens_cuda[:len(requests)]
if attn_metadata.kv_cache_block_offsets is not None and attn_metadata.host_kv_cache_block_offsets is not None and attn_metadata.host_kv_cache_pool_pointers is not None and attn_metadata.host_kv_cache_pool_mapping is not None:
if attn_metadata.kv_cache_block_offsets is not None and attn_metadata.host_kv_cache_pool_pointers is not None and attn_metadata.host_kv_cache_pool_mapping is not None:
use_paged_kv_cache = True
else:
use_paged_kv_cache = False
@@ -1262,6 +1270,20 @@ class KVCacheManager(BaseResourceManager):
else:
return None
def copy_batch_block_offsets(self, dst_tensor: torch.Tensor,
request_ids: List[int], beam_width: int,
num_context: int, num_seqs: int):
self.impl.copy_batch_block_offsets(self.host_kv_cache_block_offsets,
request_ids[:num_context], 1, 0)
self.impl.copy_batch_block_offsets(self.host_kv_cache_block_offsets,
request_ids[num_context:],
beam_width, num_context)
for pool_idx in range(self.host_kv_cache_block_offsets.shape[0]):
dst_tensor[pool_idx, :num_seqs].copy_(
self.host_kv_cache_block_offsets[pool_idx, :num_seqs],
non_blocking=True)
def reset_reuse_state(self):
"""Reset the reuse state of the KV cache manager."""
self.impl.reset_reuse_state()
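
The pinned host block-offset buffer now lives on KVCacheManager, and copy_batch_block_offsets stages offsets there before an asynchronous copy into the caller's device tensor. Below is a minimal standalone sketch of that staging pattern; fill_host_offsets is a stand-in for self.impl.copy_batch_block_offsets, which needs the C++ block manager and is not reproduced here:

    import torch


    def fill_host_offsets(host_offsets, request_ids, beam_width, first_row):
        # Stand-in for impl.copy_batch_block_offsets: write placeholder block
        # ids for each request, starting at row `first_row` of every pool.
        for i, _ in enumerate(request_ids):
            host_offsets[:, first_row + i] = i


    def copy_batch_block_offsets(host_offsets, dst_offsets, request_ids,
                                 beam_width, num_contexts, num_seqs):
        # Context requests use beam width 1; generation requests follow them.
        fill_host_offsets(host_offsets, request_ids[:num_contexts], 1, 0)
        fill_host_offsets(host_offsets, request_ids[num_contexts:], beam_width,
                          num_contexts)
        # Per-pool async host-to-device copy of only the rows used this step.
        for pool_idx in range(host_offsets.shape[0]):
            dst_offsets[pool_idx, :num_seqs].copy_(
                host_offsets[pool_idx, :num_seqs], non_blocking=True)


    if torch.cuda.is_available():
        num_pools, max_seqs, max_blocks_per_seq = 1, 8, 16
        host = torch.empty(num_pools, max_seqs, 2, max_blocks_per_seq,
                           dtype=torch.int32, pin_memory=True)
        device = torch.empty(num_pools, max_seqs, 2, max_blocks_per_seq,
                             dtype=torch.int32, device='cuda')
        copy_batch_block_offsets(host, device, list(range(6)), beam_width=1,
                                 num_contexts=4, num_seqs=6)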

View File

@@ -355,7 +355,7 @@ def _make_latent_cache_gen(
):
if rank == 0:
assert ref_attn_metadata is not None
kv_cache_block_offsets = ref_attn_metadata.host_kv_cache_block_offsets
kv_cache_block_offsets = ref_attn_metadata.kv_cache_manager.host_kv_cache_block_offsets
kv_buffer = ref_attn_metadata.kv_cache_manager.get_buffers(0)
ret = input_ctx_bs.new_empty(
(world_size - 1, input_ctx_bs.shape[0], mla.kv_lora_rank + mla.qk_rope_head_dim)

View File

@@ -56,7 +56,7 @@ def test_kv_lens_runtime_with_eagle3_one_model():
mock_kv_cache_manager.max_blocks_per_seq = 16
mock_kv_cache_manager.max_batch_size = num_seqs
mock_kv_cache_manager.max_seq_len = 512 # Large enough to hold our test sequences
mock_kv_cache_manager.impl.copy_batch_block_offsets = MagicMock()
mock_kv_cache_manager.copy_batch_block_offsets = MagicMock()
attn_metadata = TrtllmAttentionMetadata(
max_num_requests=num_seqs,
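
Since the metadata path now calls the manager-level helper, tests that mock the KV cache manager patch copy_batch_block_offsets directly instead of reaching into .impl. A small self-contained sketch with unittest.mock, independent of the real TrtllmAttentionMetadata:

    import torch
    from unittest.mock import MagicMock

    mock_kv_cache_manager = MagicMock()
    # Stub the manager-level helper; the .impl attribute no longer needs
    # patching for the block-offset copy in this code path.
    mock_kv_cache_manager.copy_batch_block_offsets = MagicMock()

    dst = torch.zeros(1, 4, 2, 16, dtype=torch.int32)
    mock_kv_cache_manager.copy_batch_block_offsets(dst, [0, 1], 1, 2, 2)
    mock_kv_cache_manager.copy_batch_block_offsets.assert_called_once()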