From 58311b2345a4cc91cf0203865513676e6127149c Mon Sep 17 00:00:00 2001 From: Yi Zhang <187001205+yizhang-nv@users.noreply.github.com> Date: Tue, 20 Jan 2026 16:08:59 +0800 Subject: [PATCH] [None][fix] Remove unused params in attn (#10652) Signed-off-by: yizhang-nv <187001205+yizhang-nv@users.noreply.github.com> --- cpp/tensorrt_llm/common/attentionOp.h | 2 - cpp/tensorrt_llm/nanobind/thop/bindings.cpp | 26 ++++---- .../gptAttentionPlugin/gptAttentionPlugin.cpp | 6 -- .../gptAttentionPlugin/gptAttentionPlugin.h | 3 +- cpp/tensorrt_llm/pybind/thop/bindings.cpp | 26 ++++---- cpp/tensorrt_llm/thop/attentionOp.cpp | 64 ++++++++----------- cpp/tensorrt_llm/thop/attentionOp.h | 26 ++++---- cpp/tensorrt_llm/thop/dsv3RopeOp.cpp | 8 +-- cpp/tensorrt_llm/thop/mlaPreprocessOp.cpp | 23 +++---- .../_torch/attention_backend/sparse/dsa.py | 5 -- .../_torch/attention_backend/trtllm.py | 34 ++-------- .../_torch/custom_ops/cpp_custom_ops.py | 1 - .../_torch/pyexecutor/resource_manager.py | 24 ++++++- .../unittest/_torch/modules/test_mla_helix.py | 2 +- .../_torch/speculative/test_eagle3.py | 2 +- 15 files changed, 110 insertions(+), 142 deletions(-) diff --git a/cpp/tensorrt_llm/common/attentionOp.h b/cpp/tensorrt_llm/common/attentionOp.h index 4776b3d92d..eb21a42e9c 100644 --- a/cpp/tensorrt_llm/common/attentionOp.h +++ b/cpp/tensorrt_llm/common/attentionOp.h @@ -127,7 +127,6 @@ public: public: // Attention packed mask input (used by context FMHA). uint32_t const* attention_packed_mask = nullptr; - kernels::KVBlockArray::DataType* host_block_offsets = nullptr; int32_t batch_size = 0; float2 const* mrope_rotary_cos_sin = nullptr; @@ -182,7 +181,6 @@ public: ss << "context_buf_sf: " << this->context_buf_sf << std::endl; ss << "key_value_cache: " << (half*) this->key_value_cache << std::endl; ss << "block_offsets: " << this->block_offsets << std::endl; - ss << "host_block_offsets: " << this->host_block_offsets << std::endl; ss << "host_primary_pool_pointer: " << this->host_primary_pool_pointer << std::endl; ss << "host_secondary_pool_pointer: " << this->host_secondary_pool_pointer << std::endl; ss << "batch_size: " << this->batch_size << std::endl; diff --git a/cpp/tensorrt_llm/nanobind/thop/bindings.cpp b/cpp/tensorrt_llm/nanobind/thop/bindings.cpp index 5f089c462b..9edfcc315a 100644 --- a/cpp/tensorrt_llm/nanobind/thop/bindings.cpp +++ b/cpp/tensorrt_llm/nanobind/thop/bindings.cpp @@ -42,19 +42,19 @@ void initBindings(nb::module_& m) nb::arg("output_sf") = std::nullopt, nb::arg("workspace_") = std::nullopt, nb::arg("sequence_length"), nb::arg("host_past_key_value_lengths"), nb::arg("host_total_kv_lens"), nb::arg("context_lengths"), nb::arg("host_context_lengths"), nb::arg("host_request_types"), - nb::arg("kv_cache_block_offsets") = std::nullopt, nb::arg("host_kv_cache_block_offsets") = std::nullopt, - nb::arg("host_kv_cache_pool_pointers") = std::nullopt, nb::arg("host_kv_cache_pool_mapping") = std::nullopt, - nb::arg("cache_indirection") = std::nullopt, nb::arg("kv_scale_orig_quant") = std::nullopt, - nb::arg("kv_scale_quant_orig") = std::nullopt, nb::arg("out_scale") = std::nullopt, - nb::arg("rotary_inv_freq") = std::nullopt, nb::arg("rotary_cos_sin") = std::nullopt, - nb::arg("latent_cache") = std::nullopt, nb::arg("q_pe") = std::nullopt, - nb::arg("block_ids_per_seq") = std::nullopt, nb::arg("attention_sinks") = std::nullopt, nb::arg("is_fused_qkv"), - nb::arg("update_kv_cache"), nb::arg("predicted_tokens_per_seq"), nb::arg("layer_idx"), nb::arg("num_heads"), - nb::arg("num_kv_heads"), 
nb::arg("head_size"), nb::arg("tokens_per_block") = std::nullopt, - nb::arg("max_num_requests"), nb::arg("max_context_length"), nb::arg("attention_window_size"), - nb::arg("sink_token_length"), nb::arg("beam_width"), nb::arg("mask_type"), nb::arg("quant_mode"), - nb::arg("q_scaling"), nb::arg("position_embedding_type"), nb::arg("rotary_embedding_dim"), - nb::arg("rotary_embedding_base"), nb::arg("rotary_embedding_scale_type"), nb::arg("rotary_embedding_scales"), + nb::arg("kv_cache_block_offsets") = std::nullopt, nb::arg("host_kv_cache_pool_pointers") = std::nullopt, + nb::arg("host_kv_cache_pool_mapping") = std::nullopt, nb::arg("cache_indirection") = std::nullopt, + nb::arg("kv_scale_orig_quant") = std::nullopt, nb::arg("kv_scale_quant_orig") = std::nullopt, + nb::arg("out_scale") = std::nullopt, nb::arg("rotary_inv_freq") = std::nullopt, + nb::arg("rotary_cos_sin") = std::nullopt, nb::arg("latent_cache") = std::nullopt, + nb::arg("q_pe") = std::nullopt, nb::arg("block_ids_per_seq") = std::nullopt, + nb::arg("attention_sinks") = std::nullopt, nb::arg("is_fused_qkv"), nb::arg("update_kv_cache"), + nb::arg("predicted_tokens_per_seq"), nb::arg("layer_idx"), nb::arg("num_heads"), nb::arg("num_kv_heads"), + nb::arg("head_size"), nb::arg("tokens_per_block") = std::nullopt, nb::arg("max_num_requests"), + nb::arg("max_context_length"), nb::arg("attention_window_size"), nb::arg("sink_token_length"), + nb::arg("beam_width"), nb::arg("mask_type"), nb::arg("quant_mode"), nb::arg("q_scaling"), + nb::arg("position_embedding_type"), nb::arg("rotary_embedding_dim"), nb::arg("rotary_embedding_base"), + nb::arg("rotary_embedding_scale_type"), nb::arg("rotary_embedding_scales"), nb::arg("rotary_embedding_max_position_info"), nb::arg("use_paged_context_fmha"), nb::arg("attention_input_type") = std::nullopt, nb::arg("is_mla_enable"), nb::arg("chunked_prefill_buffer_batch_size") = std::nullopt, nb::arg("q_lora_rank") = std::nullopt, diff --git a/cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.cpp b/cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.cpp index 34f97b9ad6..6f8c41c941 100644 --- a/cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.cpp @@ -858,7 +858,6 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32 int max_blocks_per_sequence = 0; kernels::KVBlockArray::DataType* block_offsets = nullptr; - kernels::KVBlockArray::DataType* host_block_offsets = nullptr; void* host_primary_pool_pointer = nullptr; void* host_secondary_pool_pointer = nullptr; if (useKVCache() && mPagedKVCache) @@ -882,10 +881,6 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32 = reinterpret_cast(inputs[getIdx(IdxEntry::KV_CACHE_BLOCK_OFFSETS)]) + poolOffset + seqOffset; - host_block_offsets - = reinterpret_cast(inputs[getIdx(IdxEntry::HOST_KV_CACHE_BLOCK_OFFSETS)]) - + poolOffset + seqOffset; - auto const* const typed_host_pool_pointers = static_cast(inputs[getIdx(IdxEntry::HOST_KV_CACHE_POOL_POINTERS)]); @@ -1046,7 +1041,6 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32 common_enqueue_params.max_past_kv_length = max_context_kv_len; EnqueueContextParams enqueue_params{common_enqueue_params}; enqueue_params.attention_packed_mask = attention_packed_mask; - enqueue_params.host_block_offsets = host_block_offsets; enqueue_params.batch_size = batch_size; enqueue_params.mrope_rotary_cos_sin = mrope_rotary_cos_sin; 
enqueue_params.total_kv_len = enqueue_params.num_tokens; diff --git a/cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.h b/cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.h index 13a3f0ecc6..3e34703c62 100644 --- a/cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.h +++ b/cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.h @@ -55,8 +55,7 @@ namespace tensorrt_llm::plugins // all elements must be identical. // 8. past_key_value_pool [batch_size, 2, local_num_kv_heads, max_seq_len, head_size] or // block_offsets [batch_size, 2, max_blocks_per_seq] if paged kv cache (optional) -// 8.1 host_block_offsets [batch_size, 2, max_blocks_per_seq] if paged kv cache (optional) -// 8.2 host_pool_pointers [2] if paged kv cache (optional) +// 8.1 host_pool_pointers [2] if paged kv cache (optional) // 9. kv_cache_quantization_scale [1] (optional) // 10. kv_cache_dequantization_scale [1] (optional) // 11. attention_output_quantization_scale [1] (on device, optional) diff --git a/cpp/tensorrt_llm/pybind/thop/bindings.cpp b/cpp/tensorrt_llm/pybind/thop/bindings.cpp index f1469927ce..8ecf4b1259 100644 --- a/cpp/tensorrt_llm/pybind/thop/bindings.cpp +++ b/cpp/tensorrt_llm/pybind/thop/bindings.cpp @@ -42,19 +42,19 @@ void initBindings(pybind11::module_& m) py::arg("output_sf") = std::nullopt, py::arg("workspace_") = std::nullopt, py::arg("sequence_length"), py::arg("host_past_key_value_lengths"), py::arg("host_total_kv_lens"), py::arg("context_lengths"), py::arg("host_context_lengths"), py::arg("host_request_types"), - py::arg("kv_cache_block_offsets") = std::nullopt, py::arg("host_kv_cache_block_offsets") = std::nullopt, - py::arg("host_kv_cache_pool_pointers") = std::nullopt, py::arg("host_kv_cache_pool_mapping") = std::nullopt, - py::arg("cache_indirection") = std::nullopt, py::arg("kv_scale_orig_quant") = std::nullopt, - py::arg("kv_scale_quant_orig") = std::nullopt, py::arg("out_scale") = std::nullopt, - py::arg("rotary_inv_freq") = std::nullopt, py::arg("rotary_cos_sin") = std::nullopt, - py::arg("latent_cache") = std::nullopt, py::arg("q_pe") = std::nullopt, - py::arg("block_ids_per_seq") = std::nullopt, py::arg("attention_sinks") = std::nullopt, py::arg("is_fused_qkv"), - py::arg("update_kv_cache"), py::arg("predicted_tokens_per_seq"), py::arg("layer_idx"), py::arg("num_heads"), - py::arg("num_kv_heads"), py::arg("head_size"), py::arg("tokens_per_block") = std::nullopt, - py::arg("max_num_requests"), py::arg("max_context_length"), py::arg("attention_window_size"), - py::arg("sink_token_length"), py::arg("beam_width"), py::arg("mask_type"), py::arg("quant_mode"), - py::arg("q_scaling"), py::arg("position_embedding_type"), py::arg("rotary_embedding_dim"), - py::arg("rotary_embedding_base"), py::arg("rotary_embedding_scale_type"), py::arg("rotary_embedding_scales"), + py::arg("kv_cache_block_offsets") = std::nullopt, py::arg("host_kv_cache_pool_pointers") = std::nullopt, + py::arg("host_kv_cache_pool_mapping") = std::nullopt, py::arg("cache_indirection") = std::nullopt, + py::arg("kv_scale_orig_quant") = std::nullopt, py::arg("kv_scale_quant_orig") = std::nullopt, + py::arg("out_scale") = std::nullopt, py::arg("rotary_inv_freq") = std::nullopt, + py::arg("rotary_cos_sin") = std::nullopt, py::arg("latent_cache") = std::nullopt, + py::arg("q_pe") = std::nullopt, py::arg("block_ids_per_seq") = std::nullopt, + py::arg("attention_sinks") = std::nullopt, py::arg("is_fused_qkv"), py::arg("update_kv_cache"), + py::arg("predicted_tokens_per_seq"), 
py::arg("layer_idx"), py::arg("num_heads"), py::arg("num_kv_heads"), + py::arg("head_size"), py::arg("tokens_per_block") = std::nullopt, py::arg("max_num_requests"), + py::arg("max_context_length"), py::arg("attention_window_size"), py::arg("sink_token_length"), + py::arg("beam_width"), py::arg("mask_type"), py::arg("quant_mode"), py::arg("q_scaling"), + py::arg("position_embedding_type"), py::arg("rotary_embedding_dim"), py::arg("rotary_embedding_base"), + py::arg("rotary_embedding_scale_type"), py::arg("rotary_embedding_scales"), py::arg("rotary_embedding_max_position_info"), py::arg("use_paged_context_fmha"), py::arg("attention_input_type") = std::nullopt, py::arg("is_mla_enable"), py::arg("chunked_prefill_buffer_batch_size") = std::nullopt, py::arg("q_lora_rank") = std::nullopt, diff --git a/cpp/tensorrt_llm/thop/attentionOp.cpp b/cpp/tensorrt_llm/thop/attentionOp.cpp index 688732b383..211b965bc7 100644 --- a/cpp/tensorrt_llm/thop/attentionOp.cpp +++ b/cpp/tensorrt_llm/thop/attentionOp.cpp @@ -75,7 +75,6 @@ public: torch::optional k, torch::optional v, torch::Tensor sequence_length, torch::Tensor host_past_key_value_lengths, int32_t const total_kv_len, torch::Tensor context_lengths, torch::Tensor host_context_lengths, torch::optional kv_cache_block_offsets, - torch::optional host_kv_cache_block_offsets, torch::optional host_kv_cache_pool_pointers, torch::optional host_kv_cache_pool_mapping, torch::optional cache_indirection, torch::optional kv_scale_orig_quant, torch::optional kv_scale_quant_orig, @@ -136,7 +135,6 @@ public: torch::optional k, torch::optional v, torch::Tensor sequence_length, torch::Tensor host_past_key_value_lengths, int32_t const total_kv_len, torch::Tensor context_lengths, torch::Tensor host_context_lengths, torch::optional kv_cache_block_offsets, - torch::optional host_kv_cache_block_offsets, torch::optional host_kv_cache_pool_pointers, torch::optional host_kv_cache_pool_mapping, torch::optional cache_indirection, torch::optional kv_scale_orig_quant, torch::optional kv_scale_quant_orig, @@ -289,7 +287,6 @@ public: // Commonly, cyclic_attention_window_size, and max_attention_window_size will be the same // unless each layer has different attention window sizes. - // the kv_cache capacity. int const max_attention_window_size = beam_width == 1 ? attention_window_size : cache_indirection.has_value() ? cache_indirection.value().size(2) : attention_window_size; @@ -310,10 +307,6 @@ public: = static_cast(op.useKVCache() && kv_cache_block_offsets.has_value() ? kv_cache_block_offsets.value().index({pool_index, seq_offset}).data_ptr() : nullptr); - KVBlockArray::DataType* host_block_offsets - = static_cast(op.useKVCache() && host_kv_cache_block_offsets.has_value() - ? host_kv_cache_block_offsets.value().index({pool_index, seq_offset}).data_ptr() - : nullptr); // The cache element size in bits. 
int cache_elem_bits = op.getKvCacheElemSizeInBits(); @@ -463,7 +456,6 @@ public: { common_enqueue_params.input_seq_length = max_context_q_len; AttentionOp::EnqueueContextParams enqueue_params{common_enqueue_params}; - enqueue_params.host_block_offsets = host_block_offsets; enqueue_params.batch_size = num_seqs; enqueue_params.k_ptr = k_ptr; enqueue_params.v_ptr = v_ptr; @@ -606,19 +598,19 @@ void attention(torch::Tensor q, std::optional k, std::optional output_sf, std::optional workspace_, torch::Tensor sequence_length, torch::Tensor host_past_key_value_lengths, torch::Tensor host_total_kv_lens, torch::Tensor context_lengths, torch::Tensor host_context_lengths, torch::Tensor host_request_types, - std::optional kv_cache_block_offsets, std::optional host_kv_cache_block_offsets, - std::optional host_kv_cache_pool_pointers, std::optional host_kv_cache_pool_mapping, - std::optional cache_indirection, std::optional kv_scale_orig_quant, - std::optional kv_scale_quant_orig, std::optional out_scale, - std::optional rotary_inv_freq, std::optional rotary_cos_sin, - std::optional latent_cache, std::optional q_pe, - std::optional block_ids_per_seq, std::optional attention_sinks, - bool const is_fused_qkv, bool const update_kv_cache, int64_t const predicted_tokens_per_seq, - int64_t const layer_idx, int64_t const num_heads, int64_t const num_kv_heads, int64_t const head_size, - std::optional const tokens_per_block, int64_t const max_num_requests, int64_t const max_context_length, - int64_t const attention_window_size, int64_t const sink_token_length, int64_t const beam_width, - int64_t const mask_type, int64_t const quant_mode, double const q_scaling, int64_t const position_embedding_type, - int64_t const rotary_embedding_dim, double const rotary_embedding_base, int64_t const rotary_embedding_scale_type, + std::optional kv_cache_block_offsets, std::optional host_kv_cache_pool_pointers, + std::optional host_kv_cache_pool_mapping, std::optional cache_indirection, + std::optional kv_scale_orig_quant, std::optional kv_scale_quant_orig, + std::optional out_scale, std::optional rotary_inv_freq, + std::optional rotary_cos_sin, std::optional latent_cache, + std::optional q_pe, std::optional block_ids_per_seq, + std::optional attention_sinks, bool const is_fused_qkv, bool const update_kv_cache, + int64_t const predicted_tokens_per_seq, int64_t const layer_idx, int64_t const num_heads, + int64_t const num_kv_heads, int64_t const head_size, std::optional const tokens_per_block, + int64_t const max_num_requests, int64_t const max_context_length, int64_t const attention_window_size, + int64_t const sink_token_length, int64_t const beam_width, int64_t const mask_type, int64_t const quant_mode, + double const q_scaling, int64_t const position_embedding_type, int64_t const rotary_embedding_dim, + double const rotary_embedding_base, int64_t const rotary_embedding_scale_type, std::vector rotary_embedding_scales, std::vector rotary_embedding_max_position_info, bool const use_paged_context_fmha, std::optional attention_input_type, bool is_mla_enable, std::optional chunked_prefill_buffer_batch_size, std::optional q_lora_rank, @@ -639,8 +631,8 @@ void attention(torch::Tensor q, std::optional k, std::optional k, std::optional 0) && (attn_input_type != AttentionInputType::ContextOnly)) @@ -913,13 +904,12 @@ void attention(torch::Tensor q, std::optional k, std::optional k, std::optional output_sf, std::optional workspace_, torch::Tensor sequence_length, torch::Tensor host_past_key_value_lengths, torch::Tensor host_total_kv_lens, 
torch::Tensor context_lengths, torch::Tensor host_context_lengths, torch::Tensor host_request_types, - std::optional kv_cache_block_offsets, std::optional host_kv_cache_block_offsets, - std::optional host_kv_cache_pool_pointers, std::optional host_kv_cache_pool_mapping, - std::optional cache_indirection, std::optional kv_scale_orig_quant, - std::optional kv_scale_quant_orig, std::optional out_scale, - std::optional rotary_inv_freq, std::optional rotary_cos_sin, - std::optional latent_cache, std::optional q_pe, - std::optional block_ids_per_seq, std::optional attention_sinks, - bool const is_fused_qkv, bool const update_kv_cache, int64_t const predicted_tokens_per_seq, - int64_t const layer_idx, int64_t const num_heads, int64_t const num_kv_heads, int64_t const head_size, - std::optional const tokens_per_block, int64_t const max_num_requests, int64_t const max_context_length, - int64_t const attention_window_size, int64_t const sink_token_length, int64_t const beam_width, - int64_t const mask_type, int64_t const quant_mode, double const q_scaling, int64_t const position_embedding_type, - int64_t const rotary_embedding_dim, double const rotary_embedding_base, int64_t const rotary_embedding_scale_type, + std::optional kv_cache_block_offsets, std::optional host_kv_cache_pool_pointers, + std::optional host_kv_cache_pool_mapping, std::optional cache_indirection, + std::optional kv_scale_orig_quant, std::optional kv_scale_quant_orig, + std::optional out_scale, std::optional rotary_inv_freq, + std::optional rotary_cos_sin, std::optional latent_cache, + std::optional q_pe, std::optional block_ids_per_seq, + std::optional attention_sinks, bool const is_fused_qkv, bool const update_kv_cache, + int64_t const predicted_tokens_per_seq, int64_t const layer_idx, int64_t const num_heads, + int64_t const num_kv_heads, int64_t const head_size, std::optional const tokens_per_block, + int64_t const max_num_requests, int64_t const max_context_length, int64_t const attention_window_size, + int64_t const sink_token_length, int64_t const beam_width, int64_t const mask_type, int64_t const quant_mode, + double const q_scaling, int64_t const position_embedding_type, int64_t const rotary_embedding_dim, + double const rotary_embedding_base, int64_t const rotary_embedding_scale_type, std::vector rotary_embedding_scales, std::vector rotary_embedding_max_position_info, bool const use_paged_context_fmha, std::optional attention_input_type, bool is_mla_enable, std::optional chunked_prefill_buffer_batch_size, std::optional q_lora_rank, diff --git a/cpp/tensorrt_llm/thop/dsv3RopeOp.cpp b/cpp/tensorrt_llm/thop/dsv3RopeOp.cpp index ff28f2004f..a27c394c40 100644 --- a/cpp/tensorrt_llm/thop/dsv3RopeOp.cpp +++ b/cpp/tensorrt_llm/thop/dsv3RopeOp.cpp @@ -121,8 +121,7 @@ void MLARopeGeneration(torch::Tensor fused_q, // [tokens, num_heads, (nope_dim + std::optional mla_bmm2_scale, std::optional quant_q_buffer, torch::Tensor sequence_length, torch::Tensor host_past_key_value_lengths, torch::Tensor host_context_lengths, int64_t const num_contexts, std::optional kv_cache_block_offsets, - std::optional host_kv_cache_block_offsets, std::optional host_kv_cache_pool_pointers, - std::optional host_kv_cache_pool_mapping, + std::optional host_kv_cache_pool_pointers, std::optional host_kv_cache_pool_mapping, torch::optional kv_scale_orig_quant, // [1] q,k quant scale torch::optional kv_scale_quant_orig, // [1] bmm quant scale torch::optional out_scale, // [1] output quant scale @@ -147,8 +146,8 @@ void MLARopeGeneration(torch::Tensor fused_q, // 
[tokens, num_heads, (nope_dim + TLLM_CHECK_WITH_INFO( host_kv_cache_pool_mapping.has_value(), "KV cache pool mapping is required for MLA generation."); - bool const use_kv_cache = kv_cache_block_offsets.has_value() && host_kv_cache_block_offsets.has_value() - && host_kv_cache_pool_pointers.has_value() && host_kv_cache_pool_mapping.has_value(); + bool const use_kv_cache = kv_cache_block_offsets.has_value() && host_kv_cache_pool_pointers.has_value() + && host_kv_cache_pool_mapping.has_value(); int32_t const num_seqs = host_context_lengths.size(0); @@ -331,7 +330,6 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m) ", Tensor host_context_lengths" ", int num_contexts" ", Tensor? kv_cache_block_offsets" - ", Tensor? host_kv_cache_block_offsets" ", Tensor? host_kv_cache_pool_pointers" ", Tensor? host_kv_cache_pool_mapping" ", Tensor? kv_scale_orig_quant" diff --git a/cpp/tensorrt_llm/thop/mlaPreprocessOp.cpp b/cpp/tensorrt_llm/thop/mlaPreprocessOp.cpp index 171f0d1522..7b7d93324f 100644 --- a/cpp/tensorrt_llm/thop/mlaPreprocessOp.cpp +++ b/cpp/tensorrt_llm/thop/mlaPreprocessOp.cpp @@ -159,12 +159,11 @@ KVBlockArray createKVBlockArray(int num_contexts, int max_blocks_per_sequence, i std::vector loadPagedKVCacheForMLA(torch::ScalarType out_dtype, int64_t const num_contexts, int64_t const num_ctx_cached_tokens, int64_t const max_ctx_cached_kv_len, torch::Tensor& cu_ctx_cached_kv_lens, - torch::Tensor const& kv_cache_block_offsets, torch::Tensor const& host_kv_cache_block_offsets, - torch::Tensor const& host_kv_cache_pool_pointers, torch::Tensor const& host_kv_cache_pool_mapping, - torch::optional kv_scale_orig_quant, torch::optional kv_scale_quant_orig, - int64_t const layer_idx, int64_t const lora_size, int64_t const rope_size, int64_t const tokens_per_block, - int64_t const attention_window_size, int64_t const sink_token_length, int64_t const beam_width, - int64_t const quant_mode) + torch::Tensor const& kv_cache_block_offsets, torch::Tensor const& host_kv_cache_pool_pointers, + torch::Tensor const& host_kv_cache_pool_mapping, torch::optional kv_scale_orig_quant, + torch::optional kv_scale_quant_orig, int64_t const layer_idx, int64_t const lora_size, + int64_t const rope_size, int64_t const tokens_per_block, int64_t const attention_window_size, + int64_t const sink_token_length, int64_t const beam_width, int64_t const quant_mode) { TORCH_CHECK(out_dtype == torch::kFloat16 || out_dtype == torch::kFloat32 || out_dtype == torch::kBFloat16, "out_dtype only support float16, float32, bfloat16"); @@ -348,11 +347,11 @@ void MLARopeAppendPagedKVAssignQ(torch::Tensor& q, torch::Tensor& latent_cache, torch::Tensor const& cu_ctx_cached_kv_lens, torch::Tensor const& cu_seq_lens, int64_t const max_input_uncached_seq_len, torch::Tensor const& cos_sin_cache, int64_t const head_num, int64_t const nope_size, int64_t const rope_size, int64_t const lora_size, - torch::Tensor const& kv_cache_block_offsets, torch::Tensor const& host_kv_cache_block_offsets, - torch::Tensor const& host_kv_cache_pool_pointers, torch::Tensor const& host_kv_cache_pool_mapping, - torch::optional kv_scale_orig_quant, torch::optional kv_scale_quant_orig, - int64_t const layer_idx, int64_t const tokens_per_block, int64_t const attention_window_size, - int64_t const sink_token_length, int64_t const beam_width, int64_t const quant_mode) + torch::Tensor const& kv_cache_block_offsets, torch::Tensor const& host_kv_cache_pool_pointers, + torch::Tensor const& host_kv_cache_pool_mapping, torch::optional kv_scale_orig_quant, + torch::optional kv_scale_quant_orig, 
int64_t const layer_idx, int64_t const tokens_per_block, + int64_t const attention_window_size, int64_t const sink_token_length, int64_t const beam_width, + int64_t const quant_mode) { auto input_dtype = q.scalar_type(); TORCH_CHECK(input_dtype == torch::kFloat16 || input_dtype == torch::kFloat32 || input_dtype == torch::kBFloat16); @@ -482,7 +481,6 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m) ", int max_ctx_cached_kv_len" ", Tensor cu_ctx_cached_kv_lens" ", Tensor kv_cache_block_offsets" - ", Tensor host_kv_cache_block_offsets" ", Tensor host_kv_cache_pool_pointers" ", Tensor host_kv_cache_pool_mapping" ", Tensor? kv_scale_orig_quant" @@ -550,7 +548,6 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m) ", int rope_size" ", int lora_size" ", Tensor kv_cache_block_offsets" - ", Tensor host_kv_cache_block_offsets" ", Tensor host_kv_cache_pool_pointers" ", Tensor host_kv_cache_pool_mapping" ", Tensor? kv_scale_orig_quant" diff --git a/tensorrt_llm/_torch/attention_backend/sparse/dsa.py b/tensorrt_llm/_torch/attention_backend/sparse/dsa.py index 9ae60d0c1f..d24bc51ece 100644 --- a/tensorrt_llm/_torch/attention_backend/sparse/dsa.py +++ b/tensorrt_llm/_torch/attention_backend/sparse/dsa.py @@ -1679,16 +1679,12 @@ class DSATrtllmAttention(TrtllmAttention): max_seq_len = metadata.max_gen_seq_len block_offsets = metadata.kv_cache_block_offsets[:, metadata. num_contexts:] - host_block_offsets = metadata.host_kv_cache_block_offsets[:, - metadata. - num_contexts:] else: cached_token_indptr = metadata.ctx_cached_token_indptr kv_indptr = metadata.ctx_kv_indptr num_seqs = metadata.num_contexts max_seq_len = metadata.max_ctx_seq_len block_offsets = metadata.kv_cache_block_offsets - host_block_offsets = metadata.host_kv_cache_block_offsets assert self.is_mla_enable and self.mla_params is not None assert metadata.kv_cache_manager is not None @@ -1708,7 +1704,6 @@ class DSATrtllmAttention(TrtllmAttention): self.mla_params.qk_rope_head_dim, self.mla_params.kv_lora_rank, block_offsets, - host_block_offsets, metadata.kv_cache_manager.kv_cache_pool_pointers, metadata.kv_cache_manager.kv_cache_pool_mapping, self.kv_scale_orig_quant, diff --git a/tensorrt_llm/_torch/attention_backend/trtllm.py b/tensorrt_llm/_torch/attention_backend/trtllm.py index c54776d233..3ba32dc9ac 100644 --- a/tensorrt_llm/_torch/attention_backend/trtllm.py +++ b/tensorrt_llm/_torch/attention_backend/trtllm.py @@ -35,7 +35,6 @@ class TrtllmAttentionWrapper: host_context_lengths: torch.Tensor host_request_types: torch.Tensor kv_cache_block_offsets: torch.Tensor - host_kv_cache_block_offsets: torch.Tensor host_kv_cache_pool_pointers: torch.Tensor host_kv_cache_pool_mapping: torch.Tensor workspace: Optional[torch.Tensor] @@ -181,7 +180,6 @@ class TrtllmAttentionWrapper: host_context_lengths: torch.Tensor = ..., host_request_types: torch.Tensor = ..., kv_cache_block_offsets: Optional[torch.Tensor] = None, - host_kv_cache_block_offsets: Optional[torch.Tensor] = None, host_kv_cache_pool_pointers: Optional[torch.Tensor] = None, host_kv_cache_pool_mapping: Optional[torch.Tensor] = None, block_ids_per_seq: Optional[torch.Tensor] = None, @@ -242,8 +240,7 @@ class TrtllmAttentionWrapper: context_lengths (torch.Tensor): The context-phase sequence length of each request with shape (batch_size) on GPU. host_context_lengths (torch.Tensor): Same as context_lengths, but on CPU. host_request_types (torch.Tensor): The tensor that indicates whether a request is in context or generation phase, with shape (batch_size) on CPU. 
- kv_cache_block_offsets (torch.Tensor): The offsets to the blocks inside KV cache pools on GPU, its shape is (num_pools, max_batch_size * max_beam_width, 2, max_blocks_per_sequence), one for each block. If kv_cache_block_offsets, host_kv_cache_block_offsets, host_kv_cache_pool_pointers, host_kv_cache_pool_mapping are all None, the attention will be no cache attention. - host_kv_cache_block_offsets (torch.Tensor): Same as kv_cache_block_offsets, but on CPU. + kv_cache_block_offsets (torch.Tensor): The offsets to the blocks inside KV cache pools on GPU, its shape is (num_pools, max_batch_size * max_beam_width, 2, max_blocks_per_sequence), one for each block. If kv_cache_block_offsets, host_kv_cache_pool_pointers, host_kv_cache_pool_mapping are all None, the attention will be no cache attention. host_kv_cache_pool_pointers (torch.Tensor): The pointers to the KV cache pools on CPU, its shape is (num_pools, 2), one for primary pool in GPU memory, one for secondary pool in CPU memory. host_kv_cache_pool_mapping (torch.Tensor): The index of the pool used by each attention layer on CPU, its shape is (num_local_attention_layers). The local attention layers mean all attention layers in the current PP stage in the pipeline parallelism case. workspace (torch.Tensor): An optional workspace tensor on GPU. @@ -284,7 +281,6 @@ class TrtllmAttentionWrapper: self.host_context_lengths = host_context_lengths self.host_request_types = host_request_types self.kv_cache_block_offsets = kv_cache_block_offsets - self.host_kv_cache_block_offsets = host_kv_cache_block_offsets self.host_kv_cache_pool_pointers = host_kv_cache_pool_pointers self.host_kv_cache_pool_mapping = host_kv_cache_pool_mapping self.workspace = workspace @@ -511,7 +507,6 @@ class TrtllmAttentionWrapper: self.host_context_lengths, self.host_request_types, self.kv_cache_block_offsets, - self.host_kv_cache_block_offsets, self.host_kv_cache_pool_pointers, self.host_kv_cache_pool_mapping, self.cache_indirection, @@ -789,11 +784,6 @@ class TrtllmAttentionMetadata(AttentionMetadata): dtype=torch.int32, capture_graph=capture_graph, ) - self.host_kv_cache_block_offsets = torch.empty_like( - self.kv_cache_block_offsets, - device='cpu', - pin_memory=True, - ) self.block_ids_per_seq = None self.kv_block_ids_per_seq = None if self.enable_flash_mla: @@ -933,7 +923,6 @@ class TrtllmAttentionMetadata(AttentionMetadata): # set params that are used in wrapper.plan() self.kv_cache_block_offsets = None - self.host_kv_cache_block_offsets = None self.block_ids_per_seq = None prompt_lens = torch.tensor( @@ -986,19 +975,10 @@ class TrtllmAttentionMetadata(AttentionMetadata): # kv block offsets assert self.request_ids is not None if self.kv_cache_manager is not None: - # Copy blocks for all context requests - self.kv_cache_manager.impl.copy_batch_block_offsets( - self.host_kv_cache_block_offsets, - self.request_ids[:self.num_contexts], 1, 0) - # Copy blocks for all generation requests - self.kv_cache_manager.impl.copy_batch_block_offsets( - self.host_kv_cache_block_offsets, - self.request_ids[self.num_contexts:], self.beam_width, - self.num_contexts) - for pool_idx in range(self.host_kv_cache_block_offsets.shape[0]): - self.kv_cache_block_offsets[pool_idx, :self.num_seqs].copy_( - self.host_kv_cache_block_offsets[pool_idx, :self.num_seqs], - non_blocking=True) + self.kv_cache_manager.copy_batch_block_offsets( + self.kv_cache_block_offsets, self.request_ids, self.beam_width, + self.num_contexts, self.num_seqs) + error_message = ( f"The max KV cache length of input 
sequences ({self.kv_lens[:self.num_seqs].max()}) " f"exceeds the KV cache manager's maximum supported length " @@ -1723,7 +1703,6 @@ class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]): host_context_lengths=metadata.prompt_lens_cpu_runtime, host_request_types=metadata.host_request_types_runtime, kv_cache_block_offsets=metadata.kv_cache_block_offsets, - host_kv_cache_block_offsets=metadata.host_kv_cache_block_offsets, host_kv_cache_pool_pointers=metadata.host_kv_cache_pool_pointers, host_kv_cache_pool_mapping=metadata.host_kv_cache_pool_mapping, block_ids_per_seq=metadata.block_ids_per_seq, @@ -1847,7 +1826,6 @@ class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]): metadata.max_ctx_kv_len, metadata.ctx_kv_indptr, metadata.kv_cache_block_offsets, - metadata.host_kv_cache_block_offsets, metadata.kv_cache_manager.kv_cache_pool_pointers, metadata.kv_cache_manager.kv_cache_pool_mapping, self.kv_scale_orig_quant, @@ -1938,7 +1916,6 @@ class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]): self.mla_params.qk_rope_head_dim, self.mla_params.kv_lora_rank, metadata.kv_cache_block_offsets, - metadata.host_kv_cache_block_offsets, metadata.kv_cache_manager.kv_cache_pool_pointers, metadata.kv_cache_manager.kv_cache_pool_mapping, self.kv_scale_orig_quant, @@ -2052,7 +2029,6 @@ class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]): metadata.prompt_lens_cpu_runtime, # host_context_lengths, metadata.num_contexts, metadata.kv_cache_block_offsets, - metadata.host_kv_cache_block_offsets, metadata.kv_cache_manager.kv_cache_pool_pointers, metadata.kv_cache_manager.kv_cache_pool_mapping, self.kv_scale_orig_quant, diff --git a/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py b/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py index efbbac39e3..4fc53b676b 100644 --- a/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py @@ -889,7 +889,6 @@ def _register_fake(): host_context_lengths: torch.Tensor, num_contexts: int, kv_cache_block_offsets: Optional[torch.Tensor], - host_kv_cache_block_offsets: Optional[torch.Tensor], host_kv_cache_pool_pointers: Optional[torch.Tensor], host_kv_cache_pool_mapping: Optional[torch.Tensor], kv_scale_orig_quant: Optional[torch.Tensor], diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index fd7a6d7176..477a9db00a 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -407,6 +407,14 @@ class KVCacheManager(BaseResourceManager): self.num_pools = self.impl.num_pools self.max_blocks_per_seq = self.impl.max_blocks_per_seq self.enable_block_reuse = kv_cache_config.enable_block_reuse + self.host_kv_cache_block_offsets = torch.empty(self.num_pools, + max_batch_size * + max_beam_width, + 2, + self.max_blocks_per_seq, + dtype=torch.int32, + pin_memory=True, + device='cpu') def shutdown(self): self.impl.release_pools() @@ -650,7 +658,7 @@ class KVCacheManager(BaseResourceManager): accepted_draft_token_offsets, packed_accepted_draft_tokens_indices, rewind_draft_token_separate_adjustments = self.locate_accepted_draft_tokens( requests) past_key_value_lengths = attn_metadata.kv_lens_cuda[:len(requests)] - if attn_metadata.kv_cache_block_offsets is not None and attn_metadata.host_kv_cache_block_offsets is not None and attn_metadata.host_kv_cache_pool_pointers is not None and attn_metadata.host_kv_cache_pool_mapping is not None: + if attn_metadata.kv_cache_block_offsets is not 
None and attn_metadata.host_kv_cache_pool_pointers is not None and attn_metadata.host_kv_cache_pool_mapping is not None: use_paged_kv_cache = True else: use_paged_kv_cache = False @@ -1262,6 +1270,20 @@ class KVCacheManager(BaseResourceManager): else: return None + def copy_batch_block_offsets(self, dst_tensor: torch.Tensor, + request_ids: List[int], beam_width: int, + num_context: int, num_seqs: int): + self.impl.copy_batch_block_offsets(self.host_kv_cache_block_offsets, + request_ids[:num_context], 1, 0) + self.impl.copy_batch_block_offsets(self.host_kv_cache_block_offsets, + request_ids[num_context:], + beam_width, num_context) + + for pool_idx in range(self.host_kv_cache_block_offsets.shape[0]): + dst_tensor[pool_idx, :num_seqs].copy_( + self.host_kv_cache_block_offsets[pool_idx, :num_seqs], + non_blocking=True) + def reset_reuse_state(self): """Reset the reuse state of the KV cache manager.""" self.impl.reset_reuse_state() diff --git a/tests/unittest/_torch/modules/test_mla_helix.py b/tests/unittest/_torch/modules/test_mla_helix.py index fc7aedf10e..a6a0d5202e 100644 --- a/tests/unittest/_torch/modules/test_mla_helix.py +++ b/tests/unittest/_torch/modules/test_mla_helix.py @@ -355,7 +355,7 @@ def _make_latent_cache_gen( ): if rank == 0: assert ref_attn_metadata is not None - kv_cache_block_offsets = ref_attn_metadata.host_kv_cache_block_offsets + kv_cache_block_offsets = ref_attn_metadata.kv_cache_manager.host_kv_cache_block_offsets kv_buffer = ref_attn_metadata.kv_cache_manager.get_buffers(0) ret = input_ctx_bs.new_empty( (world_size - 1, input_ctx_bs.shape[0], mla.kv_lora_rank + mla.qk_rope_head_dim) diff --git a/tests/unittest/_torch/speculative/test_eagle3.py b/tests/unittest/_torch/speculative/test_eagle3.py index c2d4cf50f4..578ff41cb3 100644 --- a/tests/unittest/_torch/speculative/test_eagle3.py +++ b/tests/unittest/_torch/speculative/test_eagle3.py @@ -56,7 +56,7 @@ def test_kv_lens_runtime_with_eagle3_one_model(): mock_kv_cache_manager.max_blocks_per_seq = 16 mock_kv_cache_manager.max_batch_size = num_seqs mock_kv_cache_manager.max_seq_len = 512 # Large enough to hold our test sequences - mock_kv_cache_manager.impl.copy_batch_block_offsets = MagicMock() + mock_kv_cache_manager.copy_batch_block_offsets = MagicMock() attn_metadata = TrtllmAttentionMetadata( max_num_requests=num_seqs,
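
Note on the Python-side refactor in this patch: TrtllmAttentionMetadata no longer allocates or stages host_kv_cache_block_offsets itself. KVCacheManager now owns the pinned host buffer and exposes copy_batch_block_offsets(dst_tensor, request_ids, beam_width, num_context, num_seqs), which gathers block offsets for context and generation requests into that pinned buffer via impl.copy_batch_block_offsets and then issues a non-blocking copy into the caller's device tensor. The snippet below is a minimal, self-contained sketch of that staging pattern only, not the TensorRT-LLM implementation; the buffer names, shapes, and the standalone copy_batch_block_offsets function here are illustrative assumptions.

    # Sketch of the pinned-host staging + non_blocking device copy pattern used by the
    # new KVCacheManager.copy_batch_block_offsets() helper. Names/shapes are illustrative.
    import torch

    num_pools, max_seqs, max_blocks_per_seq = 1, 8, 16
    pin = torch.cuda.is_available()  # pinned memory needs CUDA; fall back so the sketch runs anywhere

    # Host staging buffer (owned by the cache manager in the real code).
    host_block_offsets = torch.zeros(num_pools, max_seqs, 2, max_blocks_per_seq,
                                     dtype=torch.int32, pin_memory=pin)

    # Device tensor handed to the attention kernels (owned by the attention metadata).
    device = "cuda" if pin else "cpu"
    dst = torch.empty_like(host_block_offsets, device=device)

    def copy_batch_block_offsets(dst_tensor: torch.Tensor, num_seqs: int) -> None:
        # In the real helper, impl.copy_batch_block_offsets(...) first fills
        # host_block_offsets for the context requests and then the generation requests.
        for pool_idx in range(host_block_offsets.shape[0]):
            dst_tensor[pool_idx, :num_seqs].copy_(
                host_block_offsets[pool_idx, :num_seqs], non_blocking=True)

    copy_batch_block_offsets(dst, num_seqs=4)

Callers such as TrtllmAttentionMetadata now make a single call, e.g. self.kv_cache_manager.copy_batch_block_offsets(self.kv_cache_block_offsets, self.request_ids, self.beam_width, self.num_contexts, self.num_seqs), instead of staging the offsets inline.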