TensorRT-LLM/cpp/tensorrt_llm/thop/attentionOp.h

/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <optional>
#include <torch/extension.h>

#include "tensorrt_llm/common/config.h"

TRTLLM_NAMESPACE_BEGIN
namespace torch_ext
{
/**
* @brief Attention operation for TensorRT-LLM
*
 * This function performs multi-head attention, writing its results into the caller-provided
 * output tensor. It supports both the context (prefill) and generation (decode) phases and
 * a number of optional features, including:
* - Fused QKV processing
* - KV cache management
* - Multiple position embedding types (RoPE, ALiBi, etc.)
* - Quantization support (FP8, FP4, etc.)
 * - Multi-head latent attention (MLA)
* - Speculative decoding
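 *
 * When quantized (e.g. NVFP4) output is produced, the optional output_sf tensor receives the
 * associated scale factors. Optional parameters that correspond to features disabled for the
 * current configuration are passed as std::nullopt.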
*/
void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<torch::Tensor> v, torch::Tensor& output,
std::optional<torch::Tensor> output_sf, std::optional<torch::Tensor> workspace_, torch::Tensor sequence_length,
torch::Tensor host_past_key_value_lengths, torch::Tensor host_total_kv_lens, torch::Tensor context_lengths,
torch::Tensor host_context_lengths, torch::Tensor host_request_types,
std::optional<torch::Tensor> kv_cache_block_offsets, std::optional<torch::Tensor> host_kv_cache_block_offsets,
std::optional<torch::Tensor> host_kv_cache_pool_pointers, std::optional<torch::Tensor> host_kv_cache_pool_mapping,
std::optional<torch::Tensor> cache_indirection, std::optional<torch::Tensor> kv_scale_orig_quant,
std::optional<torch::Tensor> kv_scale_quant_orig, std::optional<torch::Tensor> out_scale,
std::optional<torch::Tensor> rotary_inv_freq, std::optional<torch::Tensor> rotary_cos_sin,
std::optional<torch::Tensor> latent_cache, std::optional<torch::Tensor> q_pe,
std::optional<torch::Tensor> block_ids_per_seq, std::optional<torch::Tensor> attention_sinks,
bool const is_fused_qkv, bool const update_kv_cache, int64_t const predicted_tokens_per_seq,
int64_t const layer_idx, int64_t const num_heads, int64_t const num_kv_heads, int64_t const head_size,
std::optional<int64_t> const tokens_per_block, int64_t const max_num_requests, int64_t const max_context_length,
int64_t const attention_window_size, int64_t const sink_token_length, int64_t const beam_width,
int64_t const mask_type, int64_t const quant_mode, double const q_scaling, int64_t const position_embedding_type,
int64_t const rotary_embedding_dim, double const rotary_embedding_base, int64_t const rotary_embedding_scale_type,
std::vector<double> rotary_embedding_scales, std::vector<int64_t> rotary_embedding_max_position_info,
bool const use_paged_context_fmha, std::optional<int64_t> attention_input_type, bool is_mla_enable,
std::optional<int64_t> chunked_prefill_buffer_batch_size, std::optional<int64_t> q_lora_rank,
std::optional<int64_t> kv_lora_rank, std::optional<int64_t> qk_nope_head_dim,
std::optional<int64_t> qk_rope_head_dim, std::optional<int64_t> v_head_dim,
std::optional<torch::Tensor> mrope_rotary_cos_sin, std::optional<torch::Tensor> mrope_position_deltas,
std::vector<std::optional<torch::Tensor>> mla_tensor_params, std::optional<int64_t> attention_chunk_size,
std::optional<torch::Tensor> softmax_stats_tensor, std::vector<bool> spec_decoding_bool_params,
std::vector<std::optional<torch::Tensor>> spec_decoding_tensor_params,
std::optional<torch::Tensor> sparse_kv_indices, std::optional<torch::Tensor> sparse_kv_offsets,
std::optional<torch::Tensor> sparse_attn_indices, std::optional<torch::Tensor> sparse_attn_offsets,
int64_t const sparse_attn_indices_block_size, std::optional<int64_t> sparse_mla_topk,
std::optional<double> skip_softmax_threshold_scale_factor_prefill,
std::optional<double> skip_softmax_threshold_scale_factor_decode, std::optional<torch::Tensor> skip_softmax_stat,
std::optional<torch::Tensor> cu_q_seqlens, std::optional<torch::Tensor> cu_kv_seqlens,
std::optional<torch::Tensor> fmha_scheduler_counter, std::optional<torch::Tensor> mla_bmm1_scale,
std::optional<torch::Tensor> mla_bmm2_scale, std::optional<torch::Tensor> quant_q_buffer);
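
// Calling-convention sketch (illustrative; the exact per-feature requirements are enforced by
// the implementation in the corresponding .cpp): the optional parameters gate individual
// features, so a configuration that does not use a feature passes std::nullopt for its
// arguments, for example:
//
//   std::optional<torch::Tensor> latent_cache = std::nullopt; // MLA latent cache not used
//   std::optional<int64_t> sparse_mla_topk = std::nullopt;    // sparse MLA top-k disabled
//   std::optional<double> skip_softmax_threshold_scale_factor_prefill = std::nullopt;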
} // namespace torch_ext
TRTLLM_NAMESPACE_END