Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-02-06 03:01:50 +08:00)
[None][fix] Remove unused params in attn (#10652)
Signed-off-by: yizhang-nv <187001205+yizhang-nv@users.noreply.github.com>
parent 47e0ec2527
commit 58311b2345
@@ -127,7 +127,6 @@ public:
public:
// Attention packed mask input (used by context FMHA).
uint32_t const* attention_packed_mask = nullptr;
kernels::KVBlockArray::DataType* host_block_offsets = nullptr;
int32_t batch_size = 0;
float2 const* mrope_rotary_cos_sin = nullptr;

@@ -182,7 +181,6 @@ public:
ss << "context_buf_sf: " << this->context_buf_sf << std::endl;
ss << "key_value_cache: " << (half*) this->key_value_cache << std::endl;
ss << "block_offsets: " << this->block_offsets << std::endl;
ss << "host_block_offsets: " << this->host_block_offsets << std::endl;
ss << "host_primary_pool_pointer: " << this->host_primary_pool_pointer << std::endl;
ss << "host_secondary_pool_pointer: " << this->host_secondary_pool_pointer << std::endl;
ss << "batch_size: " << this->batch_size << std::endl;

@@ -42,19 +42,19 @@ void initBindings(nb::module_& m)
nb::arg("output_sf") = std::nullopt, nb::arg("workspace_") = std::nullopt, nb::arg("sequence_length"),
nb::arg("host_past_key_value_lengths"), nb::arg("host_total_kv_lens"), nb::arg("context_lengths"),
nb::arg("host_context_lengths"), nb::arg("host_request_types"),
nb::arg("kv_cache_block_offsets") = std::nullopt, nb::arg("host_kv_cache_block_offsets") = std::nullopt,
nb::arg("host_kv_cache_pool_pointers") = std::nullopt, nb::arg("host_kv_cache_pool_mapping") = std::nullopt,
nb::arg("cache_indirection") = std::nullopt, nb::arg("kv_scale_orig_quant") = std::nullopt,
nb::arg("kv_scale_quant_orig") = std::nullopt, nb::arg("out_scale") = std::nullopt,
nb::arg("rotary_inv_freq") = std::nullopt, nb::arg("rotary_cos_sin") = std::nullopt,
nb::arg("latent_cache") = std::nullopt, nb::arg("q_pe") = std::nullopt,
nb::arg("block_ids_per_seq") = std::nullopt, nb::arg("attention_sinks") = std::nullopt, nb::arg("is_fused_qkv"),
nb::arg("update_kv_cache"), nb::arg("predicted_tokens_per_seq"), nb::arg("layer_idx"), nb::arg("num_heads"),
nb::arg("num_kv_heads"), nb::arg("head_size"), nb::arg("tokens_per_block") = std::nullopt,
nb::arg("max_num_requests"), nb::arg("max_context_length"), nb::arg("attention_window_size"),
nb::arg("sink_token_length"), nb::arg("beam_width"), nb::arg("mask_type"), nb::arg("quant_mode"),
nb::arg("q_scaling"), nb::arg("position_embedding_type"), nb::arg("rotary_embedding_dim"),
nb::arg("rotary_embedding_base"), nb::arg("rotary_embedding_scale_type"), nb::arg("rotary_embedding_scales"),
nb::arg("kv_cache_block_offsets") = std::nullopt, nb::arg("host_kv_cache_pool_pointers") = std::nullopt,
nb::arg("host_kv_cache_pool_mapping") = std::nullopt, nb::arg("cache_indirection") = std::nullopt,
nb::arg("kv_scale_orig_quant") = std::nullopt, nb::arg("kv_scale_quant_orig") = std::nullopt,
nb::arg("out_scale") = std::nullopt, nb::arg("rotary_inv_freq") = std::nullopt,
nb::arg("rotary_cos_sin") = std::nullopt, nb::arg("latent_cache") = std::nullopt,
nb::arg("q_pe") = std::nullopt, nb::arg("block_ids_per_seq") = std::nullopt,
nb::arg("attention_sinks") = std::nullopt, nb::arg("is_fused_qkv"), nb::arg("update_kv_cache"),
nb::arg("predicted_tokens_per_seq"), nb::arg("layer_idx"), nb::arg("num_heads"), nb::arg("num_kv_heads"),
nb::arg("head_size"), nb::arg("tokens_per_block") = std::nullopt, nb::arg("max_num_requests"),
nb::arg("max_context_length"), nb::arg("attention_window_size"), nb::arg("sink_token_length"),
nb::arg("beam_width"), nb::arg("mask_type"), nb::arg("quant_mode"), nb::arg("q_scaling"),
nb::arg("position_embedding_type"), nb::arg("rotary_embedding_dim"), nb::arg("rotary_embedding_base"),
nb::arg("rotary_embedding_scale_type"), nb::arg("rotary_embedding_scales"),
nb::arg("rotary_embedding_max_position_info"), nb::arg("use_paged_context_fmha"),
nb::arg("attention_input_type") = std::nullopt, nb::arg("is_mla_enable"),
nb::arg("chunked_prefill_buffer_batch_size") = std::nullopt, nb::arg("q_lora_rank") = std::nullopt,
@@ -858,7 +858,6 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32
int max_blocks_per_sequence = 0;
kernels::KVBlockArray::DataType* block_offsets = nullptr;
kernels::KVBlockArray::DataType* host_block_offsets = nullptr;
void* host_primary_pool_pointer = nullptr;
void* host_secondary_pool_pointer = nullptr;
if (useKVCache() && mPagedKVCache)

@@ -882,10 +881,6 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32
= reinterpret_cast<kernels::KVBlockArray::DataType*>(inputs[getIdx(IdxEntry::KV_CACHE_BLOCK_OFFSETS)])
+ poolOffset + seqOffset;
host_block_offsets
= reinterpret_cast<kernels::KVBlockArray::DataType*>(inputs[getIdx(IdxEntry::HOST_KV_CACHE_BLOCK_OFFSETS)])
+ poolOffset + seqOffset;
auto const* const typed_host_pool_pointers
= static_cast<char* const*>(inputs[getIdx(IdxEntry::HOST_KV_CACHE_POOL_POINTERS)]);

@@ -1046,7 +1041,6 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32
common_enqueue_params.max_past_kv_length = max_context_kv_len;
EnqueueContextParams<T> enqueue_params{common_enqueue_params};
enqueue_params.attention_packed_mask = attention_packed_mask;
enqueue_params.host_block_offsets = host_block_offsets;
enqueue_params.batch_size = batch_size;
enqueue_params.mrope_rotary_cos_sin = mrope_rotary_cos_sin;
enqueue_params.total_kv_len = enqueue_params.num_tokens;

@@ -55,8 +55,7 @@ namespace tensorrt_llm::plugins
// all elements must be identical.
// 8. past_key_value_pool [batch_size, 2, local_num_kv_heads, max_seq_len, head_size] or
// block_offsets [batch_size, 2, max_blocks_per_seq] if paged kv cache (optional)
// 8.1 host_block_offsets [batch_size, 2, max_blocks_per_seq] if paged kv cache (optional)
// 8.2 host_pool_pointers [2] if paged kv cache (optional)
// 8.1 host_pool_pointers [2] if paged kv cache (optional)
// 9. kv_cache_quantization_scale [1] (optional)
// 10. kv_cache_dequantization_scale [1] (optional)
// 11. attention_output_quantization_scale [1] (on device, optional)
@@ -42,19 +42,19 @@ void initBindings(pybind11::module_& m)
py::arg("output_sf") = std::nullopt, py::arg("workspace_") = std::nullopt, py::arg("sequence_length"),
py::arg("host_past_key_value_lengths"), py::arg("host_total_kv_lens"), py::arg("context_lengths"),
py::arg("host_context_lengths"), py::arg("host_request_types"),
py::arg("kv_cache_block_offsets") = std::nullopt, py::arg("host_kv_cache_block_offsets") = std::nullopt,
py::arg("host_kv_cache_pool_pointers") = std::nullopt, py::arg("host_kv_cache_pool_mapping") = std::nullopt,
py::arg("cache_indirection") = std::nullopt, py::arg("kv_scale_orig_quant") = std::nullopt,
py::arg("kv_scale_quant_orig") = std::nullopt, py::arg("out_scale") = std::nullopt,
py::arg("rotary_inv_freq") = std::nullopt, py::arg("rotary_cos_sin") = std::nullopt,
py::arg("latent_cache") = std::nullopt, py::arg("q_pe") = std::nullopt,
py::arg("block_ids_per_seq") = std::nullopt, py::arg("attention_sinks") = std::nullopt, py::arg("is_fused_qkv"),
py::arg("update_kv_cache"), py::arg("predicted_tokens_per_seq"), py::arg("layer_idx"), py::arg("num_heads"),
py::arg("num_kv_heads"), py::arg("head_size"), py::arg("tokens_per_block") = std::nullopt,
py::arg("max_num_requests"), py::arg("max_context_length"), py::arg("attention_window_size"),
py::arg("sink_token_length"), py::arg("beam_width"), py::arg("mask_type"), py::arg("quant_mode"),
py::arg("q_scaling"), py::arg("position_embedding_type"), py::arg("rotary_embedding_dim"),
py::arg("rotary_embedding_base"), py::arg("rotary_embedding_scale_type"), py::arg("rotary_embedding_scales"),
py::arg("kv_cache_block_offsets") = std::nullopt, py::arg("host_kv_cache_pool_pointers") = std::nullopt,
py::arg("host_kv_cache_pool_mapping") = std::nullopt, py::arg("cache_indirection") = std::nullopt,
py::arg("kv_scale_orig_quant") = std::nullopt, py::arg("kv_scale_quant_orig") = std::nullopt,
py::arg("out_scale") = std::nullopt, py::arg("rotary_inv_freq") = std::nullopt,
py::arg("rotary_cos_sin") = std::nullopt, py::arg("latent_cache") = std::nullopt,
py::arg("q_pe") = std::nullopt, py::arg("block_ids_per_seq") = std::nullopt,
py::arg("attention_sinks") = std::nullopt, py::arg("is_fused_qkv"), py::arg("update_kv_cache"),
py::arg("predicted_tokens_per_seq"), py::arg("layer_idx"), py::arg("num_heads"), py::arg("num_kv_heads"),
py::arg("head_size"), py::arg("tokens_per_block") = std::nullopt, py::arg("max_num_requests"),
py::arg("max_context_length"), py::arg("attention_window_size"), py::arg("sink_token_length"),
py::arg("beam_width"), py::arg("mask_type"), py::arg("quant_mode"), py::arg("q_scaling"),
py::arg("position_embedding_type"), py::arg("rotary_embedding_dim"), py::arg("rotary_embedding_base"),
py::arg("rotary_embedding_scale_type"), py::arg("rotary_embedding_scales"),
py::arg("rotary_embedding_max_position_info"), py::arg("use_paged_context_fmha"),
py::arg("attention_input_type") = std::nullopt, py::arg("is_mla_enable"),
py::arg("chunked_prefill_buffer_batch_size") = std::nullopt, py::arg("q_lora_rank") = std::nullopt,
@@ -75,7 +75,6 @@ public:
torch::optional<torch::Tensor> k, torch::optional<torch::Tensor> v, torch::Tensor sequence_length,
torch::Tensor host_past_key_value_lengths, int32_t const total_kv_len, torch::Tensor context_lengths,
torch::Tensor host_context_lengths, torch::optional<torch::Tensor> kv_cache_block_offsets,
torch::optional<torch::Tensor> host_kv_cache_block_offsets,
torch::optional<torch::Tensor> host_kv_cache_pool_pointers,
torch::optional<torch::Tensor> host_kv_cache_pool_mapping, torch::optional<torch::Tensor> cache_indirection,
torch::optional<torch::Tensor> kv_scale_orig_quant, torch::optional<torch::Tensor> kv_scale_quant_orig,

@@ -136,7 +135,6 @@ public:
torch::optional<torch::Tensor> k, torch::optional<torch::Tensor> v, torch::Tensor sequence_length,
torch::Tensor host_past_key_value_lengths, int32_t const total_kv_len, torch::Tensor context_lengths,
torch::Tensor host_context_lengths, torch::optional<torch::Tensor> kv_cache_block_offsets,
torch::optional<torch::Tensor> host_kv_cache_block_offsets,
torch::optional<torch::Tensor> host_kv_cache_pool_pointers,
torch::optional<torch::Tensor> host_kv_cache_pool_mapping, torch::optional<torch::Tensor> cache_indirection,
torch::optional<torch::Tensor> kv_scale_orig_quant, torch::optional<torch::Tensor> kv_scale_quant_orig,

@@ -289,7 +287,6 @@ public:
// Commonly, cyclic_attention_window_size, and max_attention_window_size will be the same
// unless each layer has different attention window sizes.
// the kv_cache capacity.
int const max_attention_window_size = beam_width == 1 ? attention_window_size
: cache_indirection.has_value() ? cache_indirection.value().size(2)
: attention_window_size;
@@ -310,10 +307,6 @@ public:
= static_cast<KVBlockArray::DataType*>(op.useKVCache() && kv_cache_block_offsets.has_value()
? kv_cache_block_offsets.value().index({pool_index, seq_offset}).data_ptr()
: nullptr);
KVBlockArray::DataType* host_block_offsets
= static_cast<KVBlockArray::DataType*>(op.useKVCache() && host_kv_cache_block_offsets.has_value()
? host_kv_cache_block_offsets.value().index({pool_index, seq_offset}).data_ptr()
: nullptr);
// The cache element size in bits.
int cache_elem_bits = op.getKvCacheElemSizeInBits<T>();

@@ -463,7 +456,6 @@ public:
{
common_enqueue_params.input_seq_length = max_context_q_len;
AttentionOp::EnqueueContextParams<T> enqueue_params{common_enqueue_params};
enqueue_params.host_block_offsets = host_block_offsets;
enqueue_params.batch_size = num_seqs;
enqueue_params.k_ptr = k_ptr;
enqueue_params.v_ptr = v_ptr;
@@ -606,19 +598,19 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
std::optional<torch::Tensor> output_sf, std::optional<torch::Tensor> workspace_, torch::Tensor sequence_length,
torch::Tensor host_past_key_value_lengths, torch::Tensor host_total_kv_lens, torch::Tensor context_lengths,
torch::Tensor host_context_lengths, torch::Tensor host_request_types,
std::optional<torch::Tensor> kv_cache_block_offsets, std::optional<torch::Tensor> host_kv_cache_block_offsets,
std::optional<torch::Tensor> host_kv_cache_pool_pointers, std::optional<torch::Tensor> host_kv_cache_pool_mapping,
std::optional<torch::Tensor> cache_indirection, std::optional<torch::Tensor> kv_scale_orig_quant,
std::optional<torch::Tensor> kv_scale_quant_orig, std::optional<torch::Tensor> out_scale,
std::optional<torch::Tensor> rotary_inv_freq, std::optional<torch::Tensor> rotary_cos_sin,
std::optional<torch::Tensor> latent_cache, std::optional<torch::Tensor> q_pe,
std::optional<torch::Tensor> block_ids_per_seq, std::optional<torch::Tensor> attention_sinks,
bool const is_fused_qkv, bool const update_kv_cache, int64_t const predicted_tokens_per_seq,
int64_t const layer_idx, int64_t const num_heads, int64_t const num_kv_heads, int64_t const head_size,
std::optional<int64_t> const tokens_per_block, int64_t const max_num_requests, int64_t const max_context_length,
int64_t const attention_window_size, int64_t const sink_token_length, int64_t const beam_width,
int64_t const mask_type, int64_t const quant_mode, double const q_scaling, int64_t const position_embedding_type,
int64_t const rotary_embedding_dim, double const rotary_embedding_base, int64_t const rotary_embedding_scale_type,
std::optional<torch::Tensor> kv_cache_block_offsets, std::optional<torch::Tensor> host_kv_cache_pool_pointers,
std::optional<torch::Tensor> host_kv_cache_pool_mapping, std::optional<torch::Tensor> cache_indirection,
std::optional<torch::Tensor> kv_scale_orig_quant, std::optional<torch::Tensor> kv_scale_quant_orig,
std::optional<torch::Tensor> out_scale, std::optional<torch::Tensor> rotary_inv_freq,
std::optional<torch::Tensor> rotary_cos_sin, std::optional<torch::Tensor> latent_cache,
std::optional<torch::Tensor> q_pe, std::optional<torch::Tensor> block_ids_per_seq,
std::optional<torch::Tensor> attention_sinks, bool const is_fused_qkv, bool const update_kv_cache,
int64_t const predicted_tokens_per_seq, int64_t const layer_idx, int64_t const num_heads,
int64_t const num_kv_heads, int64_t const head_size, std::optional<int64_t> const tokens_per_block,
int64_t const max_num_requests, int64_t const max_context_length, int64_t const attention_window_size,
int64_t const sink_token_length, int64_t const beam_width, int64_t const mask_type, int64_t const quant_mode,
double const q_scaling, int64_t const position_embedding_type, int64_t const rotary_embedding_dim,
double const rotary_embedding_base, int64_t const rotary_embedding_scale_type,
std::vector<double> rotary_embedding_scales, std::vector<int64_t> rotary_embedding_max_position_info,
bool const use_paged_context_fmha, std::optional<int64_t> attention_input_type, bool is_mla_enable,
std::optional<int64_t> chunked_prefill_buffer_batch_size, std::optional<int64_t> q_lora_rank,

@@ -639,8 +631,8 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
{
TLLM_LOG_TRACE("Attention op starts at layer %d", layer_idx);
// Use these tensors to infer if the attention is using KV cache
bool const use_kv_cache = kv_cache_block_offsets.has_value() && host_kv_cache_block_offsets.has_value()
&& host_kv_cache_pool_pointers.has_value() && host_kv_cache_pool_mapping.has_value();
bool const use_kv_cache = kv_cache_block_offsets.has_value() && host_kv_cache_pool_pointers.has_value()
&& host_kv_cache_pool_mapping.has_value();
TLLM_CHECK_WITH_INFO(is_mla_enable || is_fused_qkv, "Only fused QKV is supported for non-MLA attention now");
TLLM_CHECK_WITH_INFO(update_kv_cache, "KV cache update cannot be disabled now");
@@ -894,13 +886,12 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
/*num_seqs=*/num_contexts, token_offset,
/*num_tokens=*/num_ctx_tokens, predicted_tokens_per_seq, workspace, output, output_sf, qkv_or_q, k, v,
sequence_length, host_past_key_value_lengths, ctx_total_kv_len, context_lengths, host_context_lengths,
kv_cache_block_offsets, host_kv_cache_block_offsets, host_kv_cache_pool_pointers,
host_kv_cache_pool_mapping, cache_indirection, kv_scale_orig_quant, kv_scale_quant_orig, out_scale,
rotary_inv_freq, rotary_cos_sin, latent_cache, q_pe, block_ids_per_seq, mrope_rotary_cos_sin,
mrope_position_deltas, mla_tensor_params, softmax_stats_tensor, spec_decoding_tensor_params,
attention_sinks, sparse_kv_indices, sparse_kv_offsets, sparse_attn_indices, sparse_attn_offsets,
sparse_attn_indices_block_size, sparse_mla_topk_value, cu_q_seqlens, cu_kv_seqlens, fmha_scheduler_counter,
mla_bmm1_scale, mla_bmm2_scale, quant_q_buffer);
kv_cache_block_offsets, host_kv_cache_pool_pointers, host_kv_cache_pool_mapping, cache_indirection,
kv_scale_orig_quant, kv_scale_quant_orig, out_scale, rotary_inv_freq, rotary_cos_sin, latent_cache, q_pe,
block_ids_per_seq, mrope_rotary_cos_sin, mrope_position_deltas, mla_tensor_params, softmax_stats_tensor,
spec_decoding_tensor_params, attention_sinks, sparse_kv_indices, sparse_kv_offsets, sparse_attn_indices,
sparse_attn_offsets, sparse_attn_indices_block_size, sparse_mla_topk_value, cu_q_seqlens, cu_kv_seqlens,
fmha_scheduler_counter, mla_bmm1_scale, mla_bmm2_scale, quant_q_buffer);
}
if ((num_generations > 0) && (attn_input_type != AttentionInputType::ContextOnly))

@@ -913,13 +904,12 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
/*num_seqs=*/num_generations, token_offset,
/*num_tokens=*/num_gen_tokens, predicted_tokens_per_seq, workspace, output, output_sf, qkv_or_q, k, v,
sequence_length, host_past_key_value_lengths, gen_total_kv_len, context_lengths, host_context_lengths,
kv_cache_block_offsets, host_kv_cache_block_offsets, host_kv_cache_pool_pointers,
host_kv_cache_pool_mapping, cache_indirection, kv_scale_orig_quant, kv_scale_quant_orig, out_scale,
rotary_inv_freq, rotary_cos_sin, latent_cache, q_pe, block_ids_per_seq, mrope_rotary_cos_sin,
mrope_position_deltas, mla_tensor_params, softmax_stats_tensor, spec_decoding_tensor_params,
attention_sinks, sparse_kv_indices, sparse_kv_offsets, sparse_attn_indices, sparse_attn_offsets,
sparse_attn_indices_block_size, sparse_mla_topk_value, cu_q_seqlens, cu_kv_seqlens, fmha_scheduler_counter,
mla_bmm1_scale, mla_bmm2_scale, quant_q_buffer);
kv_cache_block_offsets, host_kv_cache_pool_pointers, host_kv_cache_pool_mapping, cache_indirection,
kv_scale_orig_quant, kv_scale_quant_orig, out_scale, rotary_inv_freq, rotary_cos_sin, latent_cache, q_pe,
block_ids_per_seq, mrope_rotary_cos_sin, mrope_position_deltas, mla_tensor_params, softmax_stats_tensor,
spec_decoding_tensor_params, attention_sinks, sparse_kv_indices, sparse_kv_offsets, sparse_attn_indices,
sparse_attn_offsets, sparse_attn_indices_block_size, sparse_mla_topk_value, cu_q_seqlens, cu_kv_seqlens,
fmha_scheduler_counter, mla_bmm1_scale, mla_bmm2_scale, quant_q_buffer);
}
TLLM_LOG_TRACE("Attention op stops at layer %d", layer_idx);
@@ -42,19 +42,19 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
std::optional<torch::Tensor> output_sf, std::optional<torch::Tensor> workspace_, torch::Tensor sequence_length,
torch::Tensor host_past_key_value_lengths, torch::Tensor host_total_kv_lens, torch::Tensor context_lengths,
torch::Tensor host_context_lengths, torch::Tensor host_request_types,
std::optional<torch::Tensor> kv_cache_block_offsets, std::optional<torch::Tensor> host_kv_cache_block_offsets,
std::optional<torch::Tensor> host_kv_cache_pool_pointers, std::optional<torch::Tensor> host_kv_cache_pool_mapping,
std::optional<torch::Tensor> cache_indirection, std::optional<torch::Tensor> kv_scale_orig_quant,
std::optional<torch::Tensor> kv_scale_quant_orig, std::optional<torch::Tensor> out_scale,
std::optional<torch::Tensor> rotary_inv_freq, std::optional<torch::Tensor> rotary_cos_sin,
std::optional<torch::Tensor> latent_cache, std::optional<torch::Tensor> q_pe,
std::optional<torch::Tensor> block_ids_per_seq, std::optional<torch::Tensor> attention_sinks,
bool const is_fused_qkv, bool const update_kv_cache, int64_t const predicted_tokens_per_seq,
int64_t const layer_idx, int64_t const num_heads, int64_t const num_kv_heads, int64_t const head_size,
std::optional<int64_t> const tokens_per_block, int64_t const max_num_requests, int64_t const max_context_length,
int64_t const attention_window_size, int64_t const sink_token_length, int64_t const beam_width,
int64_t const mask_type, int64_t const quant_mode, double const q_scaling, int64_t const position_embedding_type,
int64_t const rotary_embedding_dim, double const rotary_embedding_base, int64_t const rotary_embedding_scale_type,
std::optional<torch::Tensor> kv_cache_block_offsets, std::optional<torch::Tensor> host_kv_cache_pool_pointers,
std::optional<torch::Tensor> host_kv_cache_pool_mapping, std::optional<torch::Tensor> cache_indirection,
std::optional<torch::Tensor> kv_scale_orig_quant, std::optional<torch::Tensor> kv_scale_quant_orig,
std::optional<torch::Tensor> out_scale, std::optional<torch::Tensor> rotary_inv_freq,
std::optional<torch::Tensor> rotary_cos_sin, std::optional<torch::Tensor> latent_cache,
std::optional<torch::Tensor> q_pe, std::optional<torch::Tensor> block_ids_per_seq,
std::optional<torch::Tensor> attention_sinks, bool const is_fused_qkv, bool const update_kv_cache,
int64_t const predicted_tokens_per_seq, int64_t const layer_idx, int64_t const num_heads,
int64_t const num_kv_heads, int64_t const head_size, std::optional<int64_t> const tokens_per_block,
int64_t const max_num_requests, int64_t const max_context_length, int64_t const attention_window_size,
int64_t const sink_token_length, int64_t const beam_width, int64_t const mask_type, int64_t const quant_mode,
double const q_scaling, int64_t const position_embedding_type, int64_t const rotary_embedding_dim,
double const rotary_embedding_base, int64_t const rotary_embedding_scale_type,
std::vector<double> rotary_embedding_scales, std::vector<int64_t> rotary_embedding_max_position_info,
bool const use_paged_context_fmha, std::optional<int64_t> attention_input_type, bool is_mla_enable,
std::optional<int64_t> chunked_prefill_buffer_batch_size, std::optional<int64_t> q_lora_rank,
@@ -121,8 +121,7 @@ void MLARopeGeneration(torch::Tensor fused_q, // [tokens, num_heads, (nope_dim +
std::optional<torch::Tensor> mla_bmm2_scale, std::optional<torch::Tensor> quant_q_buffer,
torch::Tensor sequence_length, torch::Tensor host_past_key_value_lengths, torch::Tensor host_context_lengths,
int64_t const num_contexts, std::optional<torch::Tensor> kv_cache_block_offsets,
std::optional<torch::Tensor> host_kv_cache_block_offsets, std::optional<torch::Tensor> host_kv_cache_pool_pointers,
std::optional<torch::Tensor> host_kv_cache_pool_mapping,
std::optional<torch::Tensor> host_kv_cache_pool_pointers, std::optional<torch::Tensor> host_kv_cache_pool_mapping,
torch::optional<torch::Tensor> kv_scale_orig_quant, // [1] q,k quant scale
torch::optional<torch::Tensor> kv_scale_quant_orig, // [1] bmm quant scale
torch::optional<torch::Tensor> out_scale, // [1] output quant scale

@@ -147,8 +146,8 @@ void MLARopeGeneration(torch::Tensor fused_q, // [tokens, num_heads, (nope_dim +
TLLM_CHECK_WITH_INFO(
host_kv_cache_pool_mapping.has_value(), "KV cache pool mapping is required for MLA generation.");
bool const use_kv_cache = kv_cache_block_offsets.has_value() && host_kv_cache_block_offsets.has_value()
&& host_kv_cache_pool_pointers.has_value() && host_kv_cache_pool_mapping.has_value();
bool const use_kv_cache = kv_cache_block_offsets.has_value() && host_kv_cache_pool_pointers.has_value()
&& host_kv_cache_pool_mapping.has_value();
int32_t const num_seqs = host_context_lengths.size(0);

@@ -331,7 +330,6 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m)
", Tensor host_context_lengths"
", int num_contexts"
", Tensor? kv_cache_block_offsets"
", Tensor? host_kv_cache_block_offsets"
", Tensor? host_kv_cache_pool_pointers"
", Tensor? host_kv_cache_pool_mapping"
", Tensor? kv_scale_orig_quant"
@@ -159,12 +159,11 @@ KVBlockArray createKVBlockArray(int num_contexts, int max_blocks_per_sequence, i
std::vector<torch::Tensor> loadPagedKVCacheForMLA(torch::ScalarType out_dtype, int64_t const num_contexts,
int64_t const num_ctx_cached_tokens, int64_t const max_ctx_cached_kv_len, torch::Tensor& cu_ctx_cached_kv_lens,
torch::Tensor const& kv_cache_block_offsets, torch::Tensor const& host_kv_cache_block_offsets,
torch::Tensor const& host_kv_cache_pool_pointers, torch::Tensor const& host_kv_cache_pool_mapping,
torch::optional<torch::Tensor> kv_scale_orig_quant, torch::optional<torch::Tensor> kv_scale_quant_orig,
int64_t const layer_idx, int64_t const lora_size, int64_t const rope_size, int64_t const tokens_per_block,
int64_t const attention_window_size, int64_t const sink_token_length, int64_t const beam_width,
int64_t const quant_mode)
torch::Tensor const& kv_cache_block_offsets, torch::Tensor const& host_kv_cache_pool_pointers,
torch::Tensor const& host_kv_cache_pool_mapping, torch::optional<torch::Tensor> kv_scale_orig_quant,
torch::optional<torch::Tensor> kv_scale_quant_orig, int64_t const layer_idx, int64_t const lora_size,
int64_t const rope_size, int64_t const tokens_per_block, int64_t const attention_window_size,
int64_t const sink_token_length, int64_t const beam_width, int64_t const quant_mode)
{
TORCH_CHECK(out_dtype == torch::kFloat16 || out_dtype == torch::kFloat32 || out_dtype == torch::kBFloat16,
"out_dtype only support float16, float32, bfloat16");

@@ -348,11 +347,11 @@ void MLARopeAppendPagedKVAssignQ(torch::Tensor& q, torch::Tensor& latent_cache,
torch::Tensor const& cu_ctx_cached_kv_lens, torch::Tensor const& cu_seq_lens,
int64_t const max_input_uncached_seq_len, torch::Tensor const& cos_sin_cache, int64_t const head_num,
int64_t const nope_size, int64_t const rope_size, int64_t const lora_size,
torch::Tensor const& kv_cache_block_offsets, torch::Tensor const& host_kv_cache_block_offsets,
torch::Tensor const& host_kv_cache_pool_pointers, torch::Tensor const& host_kv_cache_pool_mapping,
torch::optional<torch::Tensor> kv_scale_orig_quant, torch::optional<torch::Tensor> kv_scale_quant_orig,
int64_t const layer_idx, int64_t const tokens_per_block, int64_t const attention_window_size,
int64_t const sink_token_length, int64_t const beam_width, int64_t const quant_mode)
torch::Tensor const& kv_cache_block_offsets, torch::Tensor const& host_kv_cache_pool_pointers,
torch::Tensor const& host_kv_cache_pool_mapping, torch::optional<torch::Tensor> kv_scale_orig_quant,
torch::optional<torch::Tensor> kv_scale_quant_orig, int64_t const layer_idx, int64_t const tokens_per_block,
int64_t const attention_window_size, int64_t const sink_token_length, int64_t const beam_width,
int64_t const quant_mode)
{
auto input_dtype = q.scalar_type();
TORCH_CHECK(input_dtype == torch::kFloat16 || input_dtype == torch::kFloat32 || input_dtype == torch::kBFloat16);

@@ -482,7 +481,6 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m)
", int max_ctx_cached_kv_len"
", Tensor cu_ctx_cached_kv_lens"
", Tensor kv_cache_block_offsets"
", Tensor host_kv_cache_block_offsets"
", Tensor host_kv_cache_pool_pointers"
", Tensor host_kv_cache_pool_mapping"
", Tensor? kv_scale_orig_quant"

@@ -550,7 +548,6 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m)
", int rope_size"
", int lora_size"
", Tensor kv_cache_block_offsets"
", Tensor host_kv_cache_block_offsets"
", Tensor host_kv_cache_pool_pointers"
", Tensor host_kv_cache_pool_mapping"
", Tensor? kv_scale_orig_quant"
@@ -1679,16 +1679,12 @@ class DSATrtllmAttention(TrtllmAttention):
max_seq_len = metadata.max_gen_seq_len
block_offsets = metadata.kv_cache_block_offsets[:, metadata.
num_contexts:]
host_block_offsets = metadata.host_kv_cache_block_offsets[:,
metadata.
num_contexts:]
else:
cached_token_indptr = metadata.ctx_cached_token_indptr
kv_indptr = metadata.ctx_kv_indptr
num_seqs = metadata.num_contexts
max_seq_len = metadata.max_ctx_seq_len
block_offsets = metadata.kv_cache_block_offsets
host_block_offsets = metadata.host_kv_cache_block_offsets
assert self.is_mla_enable and self.mla_params is not None
assert metadata.kv_cache_manager is not None

@@ -1708,7 +1704,6 @@ class DSATrtllmAttention(TrtllmAttention):
self.mla_params.qk_rope_head_dim,
self.mla_params.kv_lora_rank,
block_offsets,
host_block_offsets,
metadata.kv_cache_manager.kv_cache_pool_pointers,
metadata.kv_cache_manager.kv_cache_pool_mapping,
self.kv_scale_orig_quant,
@@ -35,7 +35,6 @@ class TrtllmAttentionWrapper:
host_context_lengths: torch.Tensor
host_request_types: torch.Tensor
kv_cache_block_offsets: torch.Tensor
host_kv_cache_block_offsets: torch.Tensor
host_kv_cache_pool_pointers: torch.Tensor
host_kv_cache_pool_mapping: torch.Tensor
workspace: Optional[torch.Tensor]

@@ -181,7 +180,6 @@ class TrtllmAttentionWrapper:
host_context_lengths: torch.Tensor = ...,
host_request_types: torch.Tensor = ...,
kv_cache_block_offsets: Optional[torch.Tensor] = None,
host_kv_cache_block_offsets: Optional[torch.Tensor] = None,
host_kv_cache_pool_pointers: Optional[torch.Tensor] = None,
host_kv_cache_pool_mapping: Optional[torch.Tensor] = None,
block_ids_per_seq: Optional[torch.Tensor] = None,
@@ -242,8 +240,7 @@ class TrtllmAttentionWrapper:
context_lengths (torch.Tensor): The context-phase sequence length of each request with shape (batch_size) on GPU.
host_context_lengths (torch.Tensor): Same as context_lengths, but on CPU.
host_request_types (torch.Tensor): The tensor that indicates whether a request is in context or generation phase, with shape (batch_size) on CPU.
kv_cache_block_offsets (torch.Tensor): The offsets to the blocks inside KV cache pools on GPU, its shape is (num_pools, max_batch_size * max_beam_width, 2, max_blocks_per_sequence), one for each block. If kv_cache_block_offsets, host_kv_cache_block_offsets, host_kv_cache_pool_pointers, host_kv_cache_pool_mapping are all None, the attention will be no cache attention.
host_kv_cache_block_offsets (torch.Tensor): Same as kv_cache_block_offsets, but on CPU.
kv_cache_block_offsets (torch.Tensor): The offsets to the blocks inside KV cache pools on GPU, its shape is (num_pools, max_batch_size * max_beam_width, 2, max_blocks_per_sequence), one for each block. If kv_cache_block_offsets, host_kv_cache_pool_pointers, host_kv_cache_pool_mapping are all None, the attention will be no cache attention.
host_kv_cache_pool_pointers (torch.Tensor): The pointers to the KV cache pools on CPU, its shape is (num_pools, 2), one for primary pool in GPU memory, one for secondary pool in CPU memory.
host_kv_cache_pool_mapping (torch.Tensor): The index of the pool used by each attention layer on CPU, its shape is (num_local_attention_layers). The local attention layers mean all attention layers in the current PP stage in the pipeline parallelism case.
workspace (torch.Tensor): An optional workspace tensor on GPU.
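The updated docstring above fixes the shapes of the remaining KV-cache arguments, so a small illustrative sketch may help to picture them. This is not part of the commit; the sizes below (num_pools, max_batch_size, and so on) are made-up placeholders, and the removed host_kv_cache_block_offsets deliberately no longer appears.

```python
import torch

# Placeholder sizes, chosen only for illustration.
num_pools, max_batch_size, max_beam_width = 1, 8, 1
max_blocks_per_sequence, num_local_attention_layers = 16, 32
device = "cuda" if torch.cuda.is_available() else "cpu"

# Block-offset table from the docstring:
# (num_pools, max_batch_size * max_beam_width, 2, max_blocks_per_sequence), on GPU.
kv_cache_block_offsets = torch.zeros(
    num_pools, max_batch_size * max_beam_width, 2, max_blocks_per_sequence,
    dtype=torch.int32, device=device)

# Pool pointers on CPU: (num_pools, 2) -> primary pool (GPU memory), secondary pool (CPU memory).
host_kv_cache_pool_pointers = torch.zeros(num_pools, 2, dtype=torch.int64)

# Pool index used by each local attention layer: (num_local_attention_layers,).
host_kv_cache_pool_mapping = torch.zeros(num_local_attention_layers, dtype=torch.int32)

# Per the docstring, passing None for all three of these selects no-cache attention.
```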
@@ -284,7 +281,6 @@ class TrtllmAttentionWrapper:
self.host_context_lengths = host_context_lengths
self.host_request_types = host_request_types
self.kv_cache_block_offsets = kv_cache_block_offsets
self.host_kv_cache_block_offsets = host_kv_cache_block_offsets
self.host_kv_cache_pool_pointers = host_kv_cache_pool_pointers
self.host_kv_cache_pool_mapping = host_kv_cache_pool_mapping
self.workspace = workspace

@@ -511,7 +507,6 @@ class TrtllmAttentionWrapper:
self.host_context_lengths,
self.host_request_types,
self.kv_cache_block_offsets,
self.host_kv_cache_block_offsets,
self.host_kv_cache_pool_pointers,
self.host_kv_cache_pool_mapping,
self.cache_indirection,
@@ -789,11 +784,6 @@ class TrtllmAttentionMetadata(AttentionMetadata):
dtype=torch.int32,
capture_graph=capture_graph,
)
self.host_kv_cache_block_offsets = torch.empty_like(
self.kv_cache_block_offsets,
device='cpu',
pin_memory=True,
)
self.block_ids_per_seq = None
self.kv_block_ids_per_seq = None
if self.enable_flash_mla:
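The block removed above allocated a pinned CPU mirror of the device block-offset table inside the attention metadata; after this commit the pinned buffer is owned by KVCacheManager (see the @@ -407,6 +407,14 @@ hunk further down). As a hedged, stand-alone illustration of the underlying pattern (a pinned host staging buffer plus a non-blocking host-to-device copy), with made-up shapes:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Device-side block-offset table (shape is a placeholder).
kv_cache_block_offsets = torch.zeros(1, 8, 2, 16, dtype=torch.int32, device=device)

# Pinned (page-locked) CPU staging buffer; pinning is what allows copy_(..., non_blocking=True)
# to overlap the host-to-device transfer with other GPU work. The removed code above used
# torch.empty_like(..., device='cpu', pin_memory=True) to the same effect.
host_kv_cache_block_offsets = torch.empty(
    kv_cache_block_offsets.shape, dtype=torch.int32,
    pin_memory=torch.cuda.is_available())

host_kv_cache_block_offsets.zero_()
kv_cache_block_offsets.copy_(host_kv_cache_block_offsets, non_blocking=True)
```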
@@ -933,7 +923,6 @@ class TrtllmAttentionMetadata(AttentionMetadata):
# set params that are used in wrapper.plan()
self.kv_cache_block_offsets = None
self.host_kv_cache_block_offsets = None
self.block_ids_per_seq = None
prompt_lens = torch.tensor(

@@ -986,19 +975,10 @@ class TrtllmAttentionMetadata(AttentionMetadata):
# kv block offsets
assert self.request_ids is not None
if self.kv_cache_manager is not None:
# Copy blocks for all context requests
self.kv_cache_manager.impl.copy_batch_block_offsets(
self.host_kv_cache_block_offsets,
self.request_ids[:self.num_contexts], 1, 0)
# Copy blocks for all generation requests
self.kv_cache_manager.impl.copy_batch_block_offsets(
self.host_kv_cache_block_offsets,
self.request_ids[self.num_contexts:], self.beam_width,
self.num_contexts)
for pool_idx in range(self.host_kv_cache_block_offsets.shape[0]):
self.kv_cache_block_offsets[pool_idx, :self.num_seqs].copy_(
self.host_kv_cache_block_offsets[pool_idx, :self.num_seqs],
non_blocking=True)
self.kv_cache_manager.copy_batch_block_offsets(
self.kv_cache_block_offsets, self.request_ids, self.beam_width,
self.num_contexts, self.num_seqs)
error_message = (
f"The max KV cache length of input sequences ({self.kv_lens[:self.num_seqs].max()}) "
f"exceeds the KV cache manager's maximum supported length "
@@ -1723,7 +1703,6 @@ class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]):
host_context_lengths=metadata.prompt_lens_cpu_runtime,
host_request_types=metadata.host_request_types_runtime,
kv_cache_block_offsets=metadata.kv_cache_block_offsets,
host_kv_cache_block_offsets=metadata.host_kv_cache_block_offsets,
host_kv_cache_pool_pointers=metadata.host_kv_cache_pool_pointers,
host_kv_cache_pool_mapping=metadata.host_kv_cache_pool_mapping,
block_ids_per_seq=metadata.block_ids_per_seq,

@@ -1847,7 +1826,6 @@ class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]):
metadata.max_ctx_kv_len,
metadata.ctx_kv_indptr,
metadata.kv_cache_block_offsets,
metadata.host_kv_cache_block_offsets,
metadata.kv_cache_manager.kv_cache_pool_pointers,
metadata.kv_cache_manager.kv_cache_pool_mapping,
self.kv_scale_orig_quant,

@@ -1938,7 +1916,6 @@ class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]):
self.mla_params.qk_rope_head_dim,
self.mla_params.kv_lora_rank,
metadata.kv_cache_block_offsets,
metadata.host_kv_cache_block_offsets,
metadata.kv_cache_manager.kv_cache_pool_pointers,
metadata.kv_cache_manager.kv_cache_pool_mapping,
self.kv_scale_orig_quant,

@@ -2052,7 +2029,6 @@ class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]):
metadata.prompt_lens_cpu_runtime, # host_context_lengths,
metadata.num_contexts,
metadata.kv_cache_block_offsets,
metadata.host_kv_cache_block_offsets,
metadata.kv_cache_manager.kv_cache_pool_pointers,
metadata.kv_cache_manager.kv_cache_pool_mapping,
self.kv_scale_orig_quant,

@@ -889,7 +889,6 @@ def _register_fake():
host_context_lengths: torch.Tensor,
num_contexts: int,
kv_cache_block_offsets: Optional[torch.Tensor],
host_kv_cache_block_offsets: Optional[torch.Tensor],
host_kv_cache_pool_pointers: Optional[torch.Tensor],
host_kv_cache_pool_mapping: Optional[torch.Tensor],
kv_scale_orig_quant: Optional[torch.Tensor],
@@ -407,6 +407,14 @@ class KVCacheManager(BaseResourceManager):
self.num_pools = self.impl.num_pools
self.max_blocks_per_seq = self.impl.max_blocks_per_seq
self.enable_block_reuse = kv_cache_config.enable_block_reuse
self.host_kv_cache_block_offsets = torch.empty(self.num_pools,
max_batch_size *
max_beam_width,
2,
self.max_blocks_per_seq,
dtype=torch.int32,
pin_memory=True,
device='cpu')

def shutdown(self):
self.impl.release_pools()
@@ -650,7 +658,7 @@ class KVCacheManager(BaseResourceManager):
accepted_draft_token_offsets, packed_accepted_draft_tokens_indices, rewind_draft_token_separate_adjustments = self.locate_accepted_draft_tokens(
requests)
past_key_value_lengths = attn_metadata.kv_lens_cuda[:len(requests)]
if attn_metadata.kv_cache_block_offsets is not None and attn_metadata.host_kv_cache_block_offsets is not None and attn_metadata.host_kv_cache_pool_pointers is not None and attn_metadata.host_kv_cache_pool_mapping is not None:
if attn_metadata.kv_cache_block_offsets is not None and attn_metadata.host_kv_cache_pool_pointers is not None and attn_metadata.host_kv_cache_pool_mapping is not None:
use_paged_kv_cache = True
else:
use_paged_kv_cache = False
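The same "all three tensors present" inference appears in the C++ attention op (the use_kv_cache hunks above). A small, hedged Python helper that mirrors the updated condition; the metadata class here is a hypothetical stub, not the real TrtllmAttentionMetadata:

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class _AttnMetadataStub:
    """Hypothetical stand-in; only the fields used by the check below."""
    kv_cache_block_offsets: Optional[torch.Tensor] = None
    host_kv_cache_pool_pointers: Optional[torch.Tensor] = None
    host_kv_cache_pool_mapping: Optional[torch.Tensor] = None


def uses_paged_kv_cache(md: _AttnMetadataStub) -> bool:
    # After this commit, host_kv_cache_block_offsets no longer participates in the check.
    return (md.kv_cache_block_offsets is not None
            and md.host_kv_cache_pool_pointers is not None
            and md.host_kv_cache_pool_mapping is not None)


assert not uses_paged_kv_cache(_AttnMetadataStub())
```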
@@ -1262,6 +1270,20 @@ class KVCacheManager(BaseResourceManager):
else:
return None

def copy_batch_block_offsets(self, dst_tensor: torch.Tensor,
request_ids: List[int], beam_width: int,
num_context: int, num_seqs: int):
self.impl.copy_batch_block_offsets(self.host_kv_cache_block_offsets,
request_ids[:num_context], 1, 0)
self.impl.copy_batch_block_offsets(self.host_kv_cache_block_offsets,
request_ids[num_context:],
beam_width, num_context)

for pool_idx in range(self.host_kv_cache_block_offsets.shape[0]):
dst_tensor[pool_idx, :num_seqs].copy_(
self.host_kv_cache_block_offsets[pool_idx, :num_seqs],
non_blocking=True)

def reset_reuse_state(self):
"""Reset the reuse state of the KV cache manager."""
self.impl.reset_reuse_state()
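Loosely mirroring the test change at the bottom of this diff (which now mocks the manager-level helper rather than impl.copy_batch_block_offsets), here is a hedged sketch of how the consolidated call path could be exercised in isolation; the manager object, sizes, and request IDs are stand-ins, not the real KVCacheManager:

```python
from unittest.mock import MagicMock

import torch

# Stand-in manager: only the attributes the helper touches.
manager = MagicMock()
manager.host_kv_cache_block_offsets = torch.zeros(1, 4, 2, 16, dtype=torch.int32)

# Callers now go through a single manager-level entry point...
manager.copy_batch_block_offsets = MagicMock()
dst = torch.empty_like(manager.host_kv_cache_block_offsets)
manager.copy_batch_block_offsets(dst, request_ids=[0, 1, 2, 3], beam_width=1,
                                 num_context=2, num_seqs=4)

# ...so a unit test only needs to assert one call instead of two impl-level ones.
manager.copy_batch_block_offsets.assert_called_once()
```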
@@ -355,7 +355,7 @@ def _make_latent_cache_gen(
):
if rank == 0:
assert ref_attn_metadata is not None
kv_cache_block_offsets = ref_attn_metadata.host_kv_cache_block_offsets
kv_cache_block_offsets = ref_attn_metadata.kv_cache_manager.host_kv_cache_block_offsets
kv_buffer = ref_attn_metadata.kv_cache_manager.get_buffers(0)
ret = input_ctx_bs.new_empty(
(world_size - 1, input_ctx_bs.shape[0], mla.kv_lora_rank + mla.qk_rope_head_dim)

@@ -56,7 +56,7 @@ def test_kv_lens_runtime_with_eagle3_one_model():
mock_kv_cache_manager.max_blocks_per_seq = 16
mock_kv_cache_manager.max_batch_size = num_seqs
mock_kv_cache_manager.max_seq_len = 512 # Large enough to hold our test sequences
mock_kv_cache_manager.impl.copy_batch_block_offsets = MagicMock()
mock_kv_cache_manager.copy_batch_block_offsets = MagicMock()

attn_metadata = TrtllmAttentionMetadata(
max_num_requests=num_seqs,