[TRTLLM-6906][chore] Using pybind to bind functions in thop/attentionOp (#6745)

Signed-off-by: Lanyu Liao <lancelly@users.noreply.github.com>
Liao Lanyu 2025-08-12 16:45:16 +08:00 committed by GitHub
parent 27fc35175e
commit f7c13a4aa7
11 changed files with 269 additions and 96 deletions

View File

@@ -17,6 +17,7 @@ set(SRCS
testing/modelSpecBinding.cpp
runtime/moeBindings.cpp
userbuffers/bindings.cpp
thop/bindings.cpp
../runtime/ipcNvlsMemory.cu
bindings.cpp)
@@ -42,7 +43,8 @@ target_link_libraries(
${Python3_LIBRARIES}
${TORCH_LIBRARIES}
torch_python
${CUDA_NVML_LIB})
${CUDA_NVML_LIB}
th_common)
target_compile_definitions(
${TRTLLM_NB_MODULE} PUBLIC TRTLLM_NB_MODULE=${TRTLLM_NB_MODULE}
PYBIND11_DETAILED_ERROR_MESSAGES=1)

View File

@@ -39,6 +39,7 @@
#include "tensorrt_llm/nanobind/executor/bindings.h"
#include "tensorrt_llm/nanobind/runtime/bindings.h"
#include "tensorrt_llm/nanobind/testing/modelSpecBinding.h"
#include "tensorrt_llm/nanobind/thop/bindings.h"
#include "tensorrt_llm/nanobind/userbuffers/bindings.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/cudaStream.h"
@@ -124,9 +125,11 @@ NB_MODULE(TRTLLM_NB_MODULE, m)
auto mInternalRuntime = mInternal.def_submodule("runtime", "Runtime internal bindings");
auto mInternalTesting = mInternal.def_submodule("testing", "Testing internal bindings");
auto mInternalBatchManager = mInternal.def_submodule("batch_manager", "Batch manager internal bindings");
auto mInternalThop = mInternal.def_submodule("thop", "Torch op internal bindings");
tensorrt_llm::nanobind::executor::initBindings(mExecutor);
tensorrt_llm::nanobind::runtime::initBindingsEarly(mInternalRuntime);
tensorrt_llm::nanobind::thop::initBindings(mInternalThop);
auto buildInfo = m.def_submodule("BuildInfo");
buildInfo.attr("ENABLE_MULTI_DEVICE") = nb::int_(ENABLE_MULTI_DEVICE);
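
The wiring above follows the pattern this module already uses for its other components: create an internal submodule, then hand it to that component's initBindings. A minimal self-contained sketch of the pattern (module, namespace, and function names are hypothetical, not from this commit):

// Minimal nanobind sketch of the submodule + initBindings layout used above.
// All names below are illustrative only.
#include <nanobind/nanobind.h>

namespace nb = nanobind;

namespace example::thop
{
void initBindings(nb::module_& m)
{
    // Each component registers its functions into the submodule it is handed.
    m.def("ping", []() { return 42; }, "Trivial function to verify the binding is loaded");
}
} // namespace example::thop

NB_MODULE(example_bindings, m)
{
    auto mInternal = m.def_submodule("internal", "Internal bindings");
    auto mInternalThop = mInternal.def_submodule("thop", "Torch op internal bindings");
    example::thop::initBindings(mInternalThop);
}

The pybind11 module further below mirrors the same structure.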

View File

@@ -0,0 +1,59 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "bindings.h"
#include <nanobind/nanobind.h>
#include <nanobind/stl/optional.h>
#include <nanobind/stl/vector.h>
#include <tensorrt_llm/thop/attentionOp.h>
#include <torch/extension.h>
namespace nb = nanobind;
namespace tensorrt_llm::nanobind::thop
{
void initBindings(nb::module_& m)
{
m.def("attention", &torch_ext::attention,
// Optional arguments default to std::nullopt, so they can be omitted or passed as None from Python
nb::arg("q"), nb::arg("k") = std::nullopt, nb::arg("v") = std::nullopt, nb::arg("output"),
nb::arg("output_sf") = std::nullopt, nb::arg("out_dtype") = std::nullopt, nb::arg("workspace_") = std::nullopt,
nb::arg("sequence_length"), nb::arg("host_past_key_value_lengths"), nb::arg("context_lengths"),
nb::arg("host_context_lengths"), nb::arg("host_request_types"),
nb::arg("kv_cache_block_offsets") = std::nullopt, nb::arg("host_kv_cache_block_offsets") = std::nullopt,
nb::arg("host_kv_cache_pool_pointers") = std::nullopt, nb::arg("host_kv_cache_pool_mapping") = std::nullopt,
nb::arg("cache_indirection") = std::nullopt, nb::arg("kv_scale_orig_quant") = std::nullopt,
nb::arg("kv_scale_quant_orig") = std::nullopt, nb::arg("out_scale") = std::nullopt,
nb::arg("rotary_inv_freq") = std::nullopt, nb::arg("rotary_cos_sin") = std::nullopt,
nb::arg("latent_cache") = std::nullopt, nb::arg("q_pe") = std::nullopt,
nb::arg("block_ids_per_seq") = std::nullopt, nb::arg("attention_sinks") = std::nullopt, nb::arg("is_fused_qkv"),
nb::arg("update_kv_cache"), nb::arg("predicted_tokens_per_seq"), nb::arg("layer_idx"), nb::arg("num_heads"),
nb::arg("num_kv_heads"), nb::arg("head_size"), nb::arg("tokens_per_block") = std::nullopt,
nb::arg("max_num_requests"), nb::arg("max_context_length"), nb::arg("attention_window_size"),
nb::arg("sink_token_length"), nb::arg("beam_width"), nb::arg("mask_type"), nb::arg("quant_mode"),
nb::arg("q_scaling"), nb::arg("position_embedding_type"), nb::arg("rotary_embedding_dim"),
nb::arg("rotary_embedding_base"), nb::arg("rotary_embedding_scale_type"), nb::arg("rotary_embedding_scales"),
nb::arg("rotary_embedding_max_position_info"), nb::arg("use_paged_context_fmha"),
nb::arg("attention_input_type") = std::nullopt, nb::arg("is_mla_enable"), nb::arg("q_lora_rank") = std::nullopt,
nb::arg("kv_lora_rank") = std::nullopt, nb::arg("qk_nope_head_dim") = std::nullopt,
nb::arg("qk_rope_head_dim") = std::nullopt, nb::arg("v_head_dim") = std::nullopt,
nb::arg("mrope_rotary_cos_sin") = std::nullopt, nb::arg("mrope_position_deltas") = std::nullopt,
nb::arg("mla_context_paged_kv") = std::nullopt, nb::arg("mla_context_kv_cache_block_offsets") = std::nullopt,
nb::arg("attention_chunk_size") = std::nullopt, nb::arg("softmax_stats_tensor") = std::nullopt,
nb::arg("spec_decoding_bool_params"), nb::arg("spec_decoding_tensor_params"), "Multi-head attention operation");
}
} // namespace tensorrt_llm::nanobind::thop
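
The binding above leans on one idiom: every std::optional parameter is declared as nb::arg("name") = std::nullopt, and <nanobind/stl/optional.h> supplies the caster that maps Python None to std::nullopt. A minimal sketch of that idiom in isolation (hypothetical function, not part of this commit):

// Minimal sketch of nanobind optional-argument defaults, as used above for
// torch_ext::attention. The function here is illustrative only.
#include <nanobind/nanobind.h>
#include <nanobind/stl/optional.h>
#include <optional>

namespace nb = nanobind;

int64_t scale_or_identity(int64_t x, std::optional<int64_t> scale)
{
    // An omitted argument, or an explicit None, arrives as std::nullopt.
    return scale.has_value() ? x * *scale : x;
}

NB_MODULE(example_optional, m)
{
    m.def("scale_or_identity", &scale_or_identity, nb::arg("x"), nb::arg("scale") = std::nullopt,
        "Scales x when scale is given, otherwise returns x unchanged");
}

This is what lets Python callers omit the long tail of optional KV-cache, RoPE, and MLA tensors when invoking attention, or pass None for them explicitly.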

View File

@@ -0,0 +1,27 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/nanobind/common/customCasters.h"
#include <nanobind/nanobind.h>
namespace tensorrt_llm::nanobind::thop
{
void initBindings(nb::module_& m);
} // namespace tensorrt_llm::nanobind::thop

View File

@@ -18,6 +18,7 @@ set(SRCS
runtime/moeBindings.cpp
userbuffers/bindings.cpp
../runtime/ipcNvlsMemory.cu
thop/bindings.cpp
bindings.cpp)
include_directories(${PROJECT_SOURCE_DIR}/include)
@@ -43,7 +44,8 @@ target_link_libraries(
${Python3_LIBRARIES}
${TORCH_LIBRARIES}
torch_python
${CUDA_NVML_LIB})
${CUDA_NVML_LIB}
th_common)
target_compile_definitions(
${TRTLLM_PYBIND_MODULE} PUBLIC TRTLLM_PYBIND_MODULE=${TRTLLM_PYBIND_MODULE}
PYBIND11_DETAILED_ERROR_MESSAGES=1)

View File

@@ -33,6 +33,7 @@
#include "tensorrt_llm/pybind/executor/bindings.h"
#include "tensorrt_llm/pybind/runtime/bindings.h"
#include "tensorrt_llm/pybind/testing/modelSpecBinding.h"
#include "tensorrt_llm/pybind/thop/bindings.h"
#include "tensorrt_llm/pybind/userbuffers/bindings.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/cudaStream.h"
@@ -116,9 +117,11 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
auto mInternalRuntime = mInternal.def_submodule("runtime", "Runtime internal bindings");
auto mInternalTesting = mInternal.def_submodule("testing", "Testing internal bindings");
auto mInternalBatchManager = mInternal.def_submodule("batch_manager", "Batch manager internal bindings");
auto mInternalThop = mInternal.def_submodule("thop", "Torch op internal bindings");
tensorrt_llm::pybind::executor::initBindings(mExecutor);
tensorrt_llm::pybind::runtime::initBindingsEarly(mInternalRuntime);
tensorrt_llm::pybind::thop::initBindings(mInternalThop);
auto buildInfo = m.def_submodule("BuildInfo");
buildInfo.attr("ENABLE_MULTI_DEVICE") = py::int_(ENABLE_MULTI_DEVICE);

View File

@@ -0,0 +1,59 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "bindings.h"
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <tensorrt_llm/thop/attentionOp.h>
#include <torch/extension.h>
namespace py = pybind11;
namespace tensorrt_llm::pybind::thop
{
void initBindings(pybind11::module_& m)
{
m.def("attention", &torch_ext::attention,
// Optional arguments default to std::nullopt, so they can be omitted or passed as None from Python
py::arg("q"), py::arg("k") = std::nullopt, py::arg("v") = std::nullopt, py::arg("output"),
py::arg("output_sf") = std::nullopt, py::arg("out_dtype") = std::nullopt, py::arg("workspace_") = std::nullopt,
py::arg("sequence_length"), py::arg("host_past_key_value_lengths"), py::arg("context_lengths"),
py::arg("host_context_lengths"), py::arg("host_request_types"),
py::arg("kv_cache_block_offsets") = std::nullopt, py::arg("host_kv_cache_block_offsets") = std::nullopt,
py::arg("host_kv_cache_pool_pointers") = std::nullopt, py::arg("host_kv_cache_pool_mapping") = std::nullopt,
py::arg("cache_indirection") = std::nullopt, py::arg("kv_scale_orig_quant") = std::nullopt,
py::arg("kv_scale_quant_orig") = std::nullopt, py::arg("out_scale") = std::nullopt,
py::arg("rotary_inv_freq") = std::nullopt, py::arg("rotary_cos_sin") = std::nullopt,
py::arg("latent_cache") = std::nullopt, py::arg("q_pe") = std::nullopt,
py::arg("block_ids_per_seq") = std::nullopt, py::arg("attention_sinks") = std::nullopt, py::arg("is_fused_qkv"),
py::arg("update_kv_cache"), py::arg("predicted_tokens_per_seq"), py::arg("layer_idx"), py::arg("num_heads"),
py::arg("num_kv_heads"), py::arg("head_size"), py::arg("tokens_per_block") = std::nullopt,
py::arg("max_num_requests"), py::arg("max_context_length"), py::arg("attention_window_size"),
py::arg("sink_token_length"), py::arg("beam_width"), py::arg("mask_type"), py::arg("quant_mode"),
py::arg("q_scaling"), py::arg("position_embedding_type"), py::arg("rotary_embedding_dim"),
py::arg("rotary_embedding_base"), py::arg("rotary_embedding_scale_type"), py::arg("rotary_embedding_scales"),
py::arg("rotary_embedding_max_position_info"), py::arg("use_paged_context_fmha"),
py::arg("attention_input_type") = std::nullopt, py::arg("is_mla_enable"), py::arg("q_lora_rank") = std::nullopt,
py::arg("kv_lora_rank") = std::nullopt, py::arg("qk_nope_head_dim") = std::nullopt,
py::arg("qk_rope_head_dim") = std::nullopt, py::arg("v_head_dim") = std::nullopt,
py::arg("mrope_rotary_cos_sin") = std::nullopt, py::arg("mrope_position_deltas") = std::nullopt,
py::arg("mla_context_paged_kv") = std::nullopt, py::arg("mla_context_kv_cache_block_offsets") = std::nullopt,
py::arg("attention_chunk_size") = std::nullopt, py::arg("softmax_stats_tensor") = std::nullopt,
py::arg("spec_decoding_bool_params"), py::arg("spec_decoding_tensor_params"), "Multi-head attention operation");
}
} // namespace tensorrt_llm::pybind::thop
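
The pybind11 registration mirrors the nanobind one; <pybind11/stl.h> provides the std::optional caster so py::arg defaults can be std::nullopt. A parallel minimal sketch (hypothetical function, not part of this commit):

// Minimal pybind11 counterpart of the optional-argument sketch shown for the
// nanobind binding. The function is illustrative only.
#include <optional>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

namespace py = pybind11;

int64_t scale_or_identity(int64_t x, std::optional<int64_t> scale)
{
    return scale.has_value() ? x * *scale : x;
}

PYBIND11_MODULE(example_optional, m)
{
    m.def("scale_or_identity", &scale_or_identity, py::arg("x"), py::arg("scale") = std::nullopt,
        "Scales x when scale is given, otherwise returns x unchanged");
}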

View File

@@ -0,0 +1,27 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/pybind/common/customCasters.h"
#include <pybind11/pybind11.h>
namespace tensorrt_llm::pybind::thop
{
void initBindings(pybind11::module_& m);
} // namespace tensorrt_llm::pybind::thop

View File

@@ -21,6 +21,7 @@
#include "tensorrt_llm/kernels/mlaKernels.h"
#include "tensorrt_llm/runtime/torchUtils.h"
#include "tensorrt_llm/runtime/utils/debugUtils.h"
#include "tensorrt_llm/thop/attentionOp.h"
#include "tensorrt_llm/thop/thUtils.h"
#include <cstdint>
#include <functional>
@@ -420,32 +421,31 @@ using RunnerPtr = std::shared_ptr<torch_ext::trtllm::attention::RunnerBase>;
using torch_ext::trtllm::attention::Runner;
using torch_ext::trtllm::attention::AttentionInputType;
void attention_inplace(torch::Tensor q, torch::optional<torch::Tensor> k, torch::optional<torch::Tensor> v,
torch::Tensor& output, torch::optional<torch::Tensor> output_sf, std::optional<torch::ScalarType> out_dtype,
torch::optional<torch::Tensor> workspace_, torch::Tensor sequence_length, torch::Tensor host_past_key_value_lengths,
void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<torch::Tensor> v, torch::Tensor& output,
std::optional<torch::Tensor> output_sf, std::optional<torch::ScalarType> out_dtype,
std::optional<torch::Tensor> workspace_, torch::Tensor sequence_length, torch::Tensor host_past_key_value_lengths,
torch::Tensor context_lengths, torch::Tensor host_context_lengths, torch::Tensor host_request_types,
torch::optional<torch::Tensor> kv_cache_block_offsets, torch::optional<torch::Tensor> host_kv_cache_block_offsets,
torch::optional<torch::Tensor> host_kv_cache_pool_pointers,
torch::optional<torch::Tensor> host_kv_cache_pool_mapping, torch::optional<torch::Tensor> cache_indirection,
torch::optional<torch::Tensor> kv_scale_orig_quant, torch::optional<torch::Tensor> kv_scale_quant_orig,
torch::optional<torch::Tensor> out_scale, torch::optional<torch::Tensor> rotary_inv_freq,
torch::optional<torch::Tensor> rotary_cos_sin, torch::optional<torch::Tensor> latent_cache,
torch::optional<torch::Tensor> q_pe, torch::optional<torch::Tensor> block_ids_per_seq,
torch::optional<torch::Tensor> attention_sinks, bool const is_fused_qkv, bool const update_kv_cache,
int64_t const predicted_tokens_per_seq, int64_t const layer_idx, int64_t const num_heads,
int64_t const num_kv_heads, int64_t const head_size, std::optional<int64_t> const tokens_per_block,
int64_t const max_num_requests, int64_t const max_context_length, int64_t const attention_window_size,
int64_t const sink_token_length, int64_t const beam_width, int64_t const mask_type, int64_t const quant_mode,
double const q_scaling, int64_t const position_embedding_type, int64_t const rotary_embedding_dim,
double const rotary_embedding_base, int64_t const rotary_embedding_scale_type,
c10::ArrayRef<double> rotary_embedding_scales, c10::ArrayRef<int64_t> rotary_embedding_max_position_info,
std::optional<torch::Tensor> kv_cache_block_offsets, std::optional<torch::Tensor> host_kv_cache_block_offsets,
std::optional<torch::Tensor> host_kv_cache_pool_pointers, std::optional<torch::Tensor> host_kv_cache_pool_mapping,
std::optional<torch::Tensor> cache_indirection, std::optional<torch::Tensor> kv_scale_orig_quant,
std::optional<torch::Tensor> kv_scale_quant_orig, std::optional<torch::Tensor> out_scale,
std::optional<torch::Tensor> rotary_inv_freq, std::optional<torch::Tensor> rotary_cos_sin,
std::optional<torch::Tensor> latent_cache, std::optional<torch::Tensor> q_pe,
std::optional<torch::Tensor> block_ids_per_seq, std::optional<torch::Tensor> attention_sinks,
bool const is_fused_qkv, bool const update_kv_cache, int64_t const predicted_tokens_per_seq,
int64_t const layer_idx, int64_t const num_heads, int64_t const num_kv_heads, int64_t const head_size,
std::optional<int64_t> const tokens_per_block, int64_t const max_num_requests, int64_t const max_context_length,
int64_t const attention_window_size, int64_t const sink_token_length, int64_t const beam_width,
int64_t const mask_type, int64_t const quant_mode, double const q_scaling, int64_t const position_embedding_type,
int64_t const rotary_embedding_dim, double const rotary_embedding_base, int64_t const rotary_embedding_scale_type,
std::vector<double> rotary_embedding_scales, std::vector<int64_t> rotary_embedding_max_position_info,
bool const use_paged_context_fmha, std::optional<int64_t> attention_input_type, bool is_mla_enable,
std::optional<int64_t> q_lora_rank, std::optional<int64_t> kv_lora_rank, std::optional<int64_t> qk_nope_head_dim,
std::optional<int64_t> qk_rope_head_dim, std::optional<int64_t> v_head_dim,
torch::optional<torch::Tensor> mrope_rotary_cos_sin, torch::optional<torch::Tensor> mrope_position_deltas,
std::optional<torch::Tensor> mrope_rotary_cos_sin, std::optional<torch::Tensor> mrope_position_deltas,
std::optional<torch::Tensor> mla_context_paged_kv, std::optional<torch::Tensor> mla_context_kv_cache_block_offsets,
std::optional<int64_t> attention_chunk_size, std::optional<torch::Tensor> softmax_stats_tensor,
c10::List<bool> spec_decoding_bool_params, c10::ArrayRef<std::optional<torch::Tensor>> spec_decoding_tensor_params)
std::vector<bool> spec_decoding_bool_params, std::vector<std::optional<torch::Tensor>> spec_decoding_tensor_params)
{
TLLM_LOG_TRACE("Attention op starts at layer %d", layer_idx);
// Use these tensors to infer if the attention is using KV cache
@@ -751,78 +751,5 @@ bool attention_supports_nvfp4_output(int64_t const num_heads, int64_t const num_
TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
m.def(
"attention_inplace("
"Tensor q"
", Tensor? k"
", Tensor? v"
", Tensor(a!) output"
", Tensor(b!)? output_sf"
", ScalarType? out_dtype"
", Tensor? workspace"
", Tensor sequence_length"
", Tensor host_past_key_value_lengths"
", Tensor context_lengths"
", Tensor host_context_lengths"
", Tensor host_request_types"
", Tensor? kv_cache_block_offsets"
", Tensor? host_kv_cache_block_offsets"
", Tensor? host_kv_cache_pool_pointers"
", Tensor? host_kv_cache_pool_mapping"
", Tensor? cache_indirection"
", Tensor? kv_scale_orig_quant"
", Tensor? kv_scale_quant_orig"
", Tensor? out_scale"
", Tensor? rotary_inv_freq"
", Tensor? rotary_cos_sin"
", Tensor? latent_cache"
", Tensor? q_pe"
", Tensor? block_ids_per_seq"
", Tensor? attention_sinks"
", bool is_fused_qkv"
", bool update_kv_cache"
", int predicted_tokens_per_seq"
", int layer_idx"
", int num_heads"
", int num_kv_heads"
", int head_size"
", SymInt? tokens_per_block"
", SymInt max_num_requests"
", SymInt max_context_length"
", SymInt attention_window_size"
", int sink_token_length"
", int beam_width"
", int mask_type"
", int quant_mode"
", float q_scaling"
", int position_embedding_type"
", int rotary_embedding_dim"
", float rotary_embedding_base"
", int rotary_embedding_scale_type"
", float[] rotary_embedding_scales"
", int[] rotary_embedding_max_position_info"
", bool use_paged_context_fmha"
", int? attention_input_type"
", bool is_mla_enable"
", int? q_lora_rank"
", int? kv_lora_rank"
", int? qk_nope_head_dim"
", int? qk_rope_head_dim"
", int? v_head_dim"
", Tensor? mrope_rotary_cos_sin"
", Tensor? mrope_position_deltas"
", Tensor? mla_context_paged_kv"
", Tensor? mla_context_kv_cache_block_offsets"
", int? attention_chunk_size"
", Tensor? softmax_stats_tensor"
", bool[] spec_decoding_bool_params"
", Tensor?[] spec_decoding_tensor_params"
") -> ()");
m.def("attention_supports_nvfp4_output", &torch_ext::attention_supports_nvfp4_output);
}
TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
m.impl("attention_inplace", &torch_ext::attention_inplace);
}
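
Per the hunk header, this block shrinks from 78 lines to 5: the schema-string declaration of attention_inplace and its TORCH_LIBRARY_IMPL registration are removed, since callers now reach the op through the new thop bindings, while the direct m.def of attention_supports_nvfp4_output appears to be what remains. A standalone sketch contrasting the two TorchScript registration styles (hypothetical namespace and ops; TORCH_LIBRARY is used instead of the fragment macro so the example stands alone):

// Standalone sketch of the two registration styles seen above; all names are
// illustrative. The commit itself extends the existing "trtllm" namespace via
// TORCH_LIBRARY_FRAGMENT.
#include <torch/extension.h>
#include <torch/library.h>

void double_inplace(torch::Tensor x)
{
    // Mutates its argument, matching the Tensor(a!) annotation in the schema.
    x.mul_(2);
}

bool supports_feature(int64_t num_heads)
{
    return num_heads % 8 == 0;
}

TORCH_LIBRARY(example_ns, m)
{
    // Style removed for attention: declare the schema string here and provide
    // per-backend implementations separately via TORCH_LIBRARY_IMPL.
    m.def("double_inplace(Tensor(a!) x) -> ()");
    // Style kept for attention_supports_nvfp4_output: register the function
    // directly and let the schema be inferred from the C++ signature.
    m.def("supports_feature", &supports_feature);
}

TORCH_LIBRARY_IMPL(example_ns, CPU, m)
{
    m.impl("double_inplace", &double_inplace);
}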

View File

@@ -0,0 +1,63 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <optional>
#include <torch/extension.h>
namespace torch_ext
{
/**
* @brief Attention operation for TensorRT-LLM
*
* This function performs multi-head attention computation in-place, supporting both
* context and generation phases with various optimization features including:
* - Fused QKV processing
* - KV cache management
* - Multiple position embedding types (RoPE, ALiBi, etc.)
* - Quantization support (FP8, FP4, etc.)
* - Multi-head latent attention (MLA)
* - Speculative decoding
*/
void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<torch::Tensor> v, torch::Tensor& output,
std::optional<torch::Tensor> output_sf, std::optional<torch::ScalarType> out_dtype,
std::optional<torch::Tensor> workspace_, torch::Tensor sequence_length, torch::Tensor host_past_key_value_lengths,
torch::Tensor context_lengths, torch::Tensor host_context_lengths, torch::Tensor host_request_types,
std::optional<torch::Tensor> kv_cache_block_offsets, std::optional<torch::Tensor> host_kv_cache_block_offsets,
std::optional<torch::Tensor> host_kv_cache_pool_pointers, std::optional<torch::Tensor> host_kv_cache_pool_mapping,
std::optional<torch::Tensor> cache_indirection, std::optional<torch::Tensor> kv_scale_orig_quant,
std::optional<torch::Tensor> kv_scale_quant_orig, std::optional<torch::Tensor> out_scale,
std::optional<torch::Tensor> rotary_inv_freq, std::optional<torch::Tensor> rotary_cos_sin,
std::optional<torch::Tensor> latent_cache, std::optional<torch::Tensor> q_pe,
std::optional<torch::Tensor> block_ids_per_seq, std::optional<torch::Tensor> attention_sinks,
bool const is_fused_qkv, bool const update_kv_cache, int64_t const predicted_tokens_per_seq,
int64_t const layer_idx, int64_t const num_heads, int64_t const num_kv_heads, int64_t const head_size,
std::optional<int64_t> const tokens_per_block, int64_t const max_num_requests, int64_t const max_context_length,
int64_t const attention_window_size, int64_t const sink_token_length, int64_t const beam_width,
int64_t const mask_type, int64_t const quant_mode, double const q_scaling, int64_t const position_embedding_type,
int64_t const rotary_embedding_dim, double const rotary_embedding_base, int64_t const rotary_embedding_scale_type,
std::vector<double> rotary_embedding_scales, std::vector<int64_t> rotary_embedding_max_position_info,
bool const use_paged_context_fmha, std::optional<int64_t> attention_input_type, bool is_mla_enable,
std::optional<int64_t> q_lora_rank, std::optional<int64_t> kv_lora_rank, std::optional<int64_t> qk_nope_head_dim,
std::optional<int64_t> qk_rope_head_dim, std::optional<int64_t> v_head_dim,
torch::optional<torch::Tensor> mrope_rotary_cos_sin, torch::optional<torch::Tensor> mrope_position_deltas,
std::optional<torch::Tensor> mla_context_paged_kv, std::optional<torch::Tensor> mla_context_kv_cache_block_offsets,
std::optional<int64_t> attention_chunk_size, std::optional<torch::Tensor> softmax_stats_tensor,
std::vector<bool> spec_decoding_bool_params, std::vector<std::optional<torch::Tensor>> spec_decoding_tensor_params);
} // namespace torch_ext

View File

@@ -7,6 +7,7 @@ from typing import Optional, Tuple, Union
import torch
from tensorrt_llm._utils import get_sm_version
from tensorrt_llm.bindings.internal import thop
from tensorrt_llm.functional import AttentionMaskType
from tensorrt_llm.logger import logger
from tensorrt_llm.models.modeling_utils import QuantConfig
@@ -419,7 +420,7 @@ class TrtllmAttentionWrapper:
self.spec_decoding_position_offsets, self.spec_decoding_packed_mask
]
torch.ops.trtllm.attention_inplace(
thop.attention(
q,
k,
v,