/*
 * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <torch/extension.h>

#include <optional>
#include <vector>

#include "tensorrt_llm/common/config.h"

TRTLLM_NAMESPACE_BEGIN

namespace torch_ext
{

/**
 * @brief Attention operation for TensorRT-LLM
 *
 * This function performs multi-head attention and writes the result into the provided
 * output tensor, supporting both the context and generation phases with various
 * optimization features, including:
 * - Fused QKV processing
 * - KV cache management
 * - Multiple position embedding types (RoPE, ALiBi, etc.)
 * - Quantization support (FP8, FP4, etc.)
 * - Multi-head latent attention (MLA)
 * - Speculative decoding
 */
void attention(
    // Input tensors: fused QKV (or separate Q/K/V), the output buffer, and workspace.
    torch::Tensor q, std::optional<torch::Tensor> k, std::optional<torch::Tensor> v, torch::Tensor& output,
    std::optional<torch::Tensor> output_sf, std::optional<c10::ScalarType> out_dtype,
    std::optional<torch::Tensor> workspace_,
    // Sequence-length and request metadata.
    torch::Tensor sequence_length, torch::Tensor host_past_key_value_lengths, torch::Tensor host_total_kv_lens,
    torch::Tensor context_lengths, torch::Tensor host_context_lengths, torch::Tensor host_request_types,
    // Paged KV cache layout and quantization scaling factors.
    std::optional<torch::Tensor> kv_cache_block_offsets, std::optional<torch::Tensor> host_kv_cache_block_offsets,
    std::optional<torch::Tensor> host_kv_cache_pool_pointers, std::optional<torch::Tensor> host_kv_cache_pool_mapping,
    std::optional<torch::Tensor> cache_indirection, std::optional<torch::Tensor> kv_scale_orig_quant,
    std::optional<torch::Tensor> kv_scale_quant_orig, std::optional<torch::Tensor> out_scale,
    // Rotary embedding tables, MLA inputs, and attention sinks.
    std::optional<torch::Tensor> rotary_inv_freq, std::optional<torch::Tensor> rotary_cos_sin,
    std::optional<torch::Tensor> latent_cache, std::optional<torch::Tensor> q_pe,
    std::optional<torch::Tensor> block_ids_per_seq, std::optional<torch::Tensor> attention_sinks,
    // Model and runtime configuration scalars.
    bool const is_fused_qkv, bool const update_kv_cache, int64_t const predicted_tokens_per_seq,
    int64_t const layer_idx, int64_t const num_heads, int64_t const num_kv_heads, int64_t const head_size,
    std::optional<int64_t> const tokens_per_block, int64_t const max_num_requests, int64_t const max_context_length,
    int64_t const attention_window_size, int64_t const sink_token_length, int64_t const beam_width,
    int64_t const mask_type, int64_t const quant_mode, double const q_scaling,
    // Position embedding configuration.
    int64_t const position_embedding_type, int64_t const rotary_embedding_dim, double const rotary_embedding_base,
    int64_t const rotary_embedding_scale_type, std::vector<double> rotary_embedding_scales,
    std::vector<int64_t> rotary_embedding_max_position_info,
    // Context FMHA and MLA configuration.
    bool const use_paged_context_fmha, std::optional<int64_t> attention_input_type, bool is_mla_enable,
    std::optional<int64_t> chunked_prefill_buffer_batch_size, std::optional<int64_t> q_lora_rank,
    std::optional<int64_t> kv_lora_rank, std::optional<int64_t> qk_nope_head_dim,
    std::optional<int64_t> qk_rope_head_dim, std::optional<int64_t> v_head_dim,
    std::optional<torch::Tensor> mrope_rotary_cos_sin, std::optional<torch::Tensor> mrope_position_deltas,
    std::vector<std::optional<torch::Tensor>> mla_tensor_params, std::optional<int64_t> attention_chunk_size,
    std::optional<torch::Tensor> softmax_stats_tensor,
    // Speculative decoding parameters.
    std::vector<bool> spec_decoding_bool_params, std::vector<std::optional<torch::Tensor>> spec_decoding_tensor_params,
    // Sparse attention parameters.
    std::optional<torch::Tensor> sparse_kv_indices, std::optional<torch::Tensor> sparse_kv_offsets,
    std::optional<torch::Tensor> sparse_attn_indices, std::optional<torch::Tensor> sparse_attn_offsets,
    int64_t const sparse_attn_indices_block_size, std::optional<int64_t> sparse_mla_topk,
    // FMHA scheduling and MLA quantization scratch buffers.
    std::optional<torch::Tensor> cu_q_seqlens, std::optional<torch::Tensor> cu_kv_seqlens,
    std::optional<torch::Tensor> fmha_scheduler_counter, std::optional<torch::Tensor> mla_bmm1_scale,
    std::optional<torch::Tensor> mla_bmm2_scale, std::optional<torch::Tensor> quant_q_buffer);

} // namespace torch_ext

TRTLLM_NAMESPACE_END
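
/*
 * Illustrative sketch only (kept inside a comment so it does not affect this header):
 * a declaration like the one above is typically implemented in a companion .cpp and
 * exposed to Python through PyTorch's custom-op registration macros. The reduced
 * signature, the "my_ext" namespace, and the toy_attention kernel below are
 * hypothetical stand-ins that demonstrate the registration pattern; they are not the
 * actual TensorRT-LLM binding or schema.
 *
 *   #include <torch/extension.h>
 *
 *   // Minimal scaled-dot-product attention used only to demonstrate the binding.
 *   torch::Tensor toy_attention(torch::Tensor q, torch::Tensor k, torch::Tensor v, double scale)
 *   {
 *       auto scores = torch::matmul(q, k.transpose(-2, -1)) * scale;
 *       return torch::matmul(torch::softmax(scores, /*dim=* / -1), v);
 *   }
 *
 *   // Declare the op schema once for the extension library...
 *   TORCH_LIBRARY_FRAGMENT(my_ext, m)
 *   {
 *       m.def("toy_attention(Tensor q, Tensor k, Tensor v, float scale) -> Tensor");
 *   }
 *
 *   // ...and register an implementation for it.
 *   TORCH_LIBRARY_IMPL(my_ext, CompositeExplicitAutograd, m)
 *   {
 *       m.impl("toy_attention", &toy_attention);
 *   }
 *
 * After building the extension, the op would be callable from Python as
 * torch.ops.my_ext.toy_attention(q, k, v, scale).
 */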