[TRTLLM-6906][chore] Using pybind to bind functions in thop/attentionOp (#6745)
Signed-off-by: Lanyu Liao <lancelly@users.noreply.github.com>
parent 27fc35175e
commit f7c13a4aa7

@@ -17,6 +17,7 @@ set(SRCS
   testing/modelSpecBinding.cpp
   runtime/moeBindings.cpp
   userbuffers/bindings.cpp
+  thop/bindings.cpp
   ../runtime/ipcNvlsMemory.cu
   bindings.cpp)

@@ -42,7 +43,8 @@ target_link_libraries(
   ${Python3_LIBRARIES}
   ${TORCH_LIBRARIES}
   torch_python
-  ${CUDA_NVML_LIB})
+  ${CUDA_NVML_LIB}
+  th_common)
 target_compile_definitions(
   ${TRTLLM_NB_MODULE} PUBLIC TRTLLM_NB_MODULE=${TRTLLM_NB_MODULE}
                              PYBIND11_DETAILED_ERROR_MESSAGES=1)

@@ -39,6 +39,7 @@
 #include "tensorrt_llm/nanobind/executor/bindings.h"
 #include "tensorrt_llm/nanobind/runtime/bindings.h"
 #include "tensorrt_llm/nanobind/testing/modelSpecBinding.h"
+#include "tensorrt_llm/nanobind/thop/bindings.h"
 #include "tensorrt_llm/nanobind/userbuffers/bindings.h"
 #include "tensorrt_llm/runtime/common.h"
 #include "tensorrt_llm/runtime/cudaStream.h"

@@ -124,9 +125,11 @@ NB_MODULE(TRTLLM_NB_MODULE, m)
     auto mInternalRuntime = mInternal.def_submodule("runtime", "Runtime internal bindings");
     auto mInternalTesting = mInternal.def_submodule("testing", "Testing internal bindings");
     auto mInternalBatchManager = mInternal.def_submodule("batch_manager", "Batch manager internal bindings");
+    auto mInternalThop = mInternal.def_submodule("thop", "Torch op internal bindings");

     tensorrt_llm::nanobind::executor::initBindings(mExecutor);
     tensorrt_llm::nanobind::runtime::initBindingsEarly(mInternalRuntime);
+    tensorrt_llm::nanobind::thop::initBindings(mInternalThop);

     auto buildInfo = m.def_submodule("BuildInfo");
     buildInfo.attr("ENABLE_MULTI_DEVICE") = nb::int_(ENABLE_MULTI_DEVICE);

cpp/tensorrt_llm/nanobind/thop/bindings.cpp (new file, 59 lines)
@@ -0,0 +1,59 @@
/*
 * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "bindings.h"
#include <nanobind/nanobind.h>
#include <nanobind/stl/optional.h>
#include <nanobind/stl/vector.h>
#include <tensorrt_llm/thop/attentionOp.h>
#include <torch/extension.h>

namespace nb = nanobind;

namespace tensorrt_llm::nanobind::thop
{

void initBindings(nb::module_& m)
{
    m.def("attention", &torch_ext::attention,
        // Parameters with default values using std::nullopt for optional arguments
        nb::arg("q"), nb::arg("k") = std::nullopt, nb::arg("v") = std::nullopt, nb::arg("output"),
        nb::arg("output_sf") = std::nullopt, nb::arg("out_dtype") = std::nullopt, nb::arg("workspace_") = std::nullopt,
        nb::arg("sequence_length"), nb::arg("host_past_key_value_lengths"), nb::arg("context_lengths"),
        nb::arg("host_context_lengths"), nb::arg("host_request_types"),
        nb::arg("kv_cache_block_offsets") = std::nullopt, nb::arg("host_kv_cache_block_offsets") = std::nullopt,
        nb::arg("host_kv_cache_pool_pointers") = std::nullopt, nb::arg("host_kv_cache_pool_mapping") = std::nullopt,
        nb::arg("cache_indirection") = std::nullopt, nb::arg("kv_scale_orig_quant") = std::nullopt,
        nb::arg("kv_scale_quant_orig") = std::nullopt, nb::arg("out_scale") = std::nullopt,
        nb::arg("rotary_inv_freq") = std::nullopt, nb::arg("rotary_cos_sin") = std::nullopt,
        nb::arg("latent_cache") = std::nullopt, nb::arg("q_pe") = std::nullopt,
        nb::arg("block_ids_per_seq") = std::nullopt, nb::arg("attention_sinks") = std::nullopt, nb::arg("is_fused_qkv"),
        nb::arg("update_kv_cache"), nb::arg("predicted_tokens_per_seq"), nb::arg("layer_idx"), nb::arg("num_heads"),
        nb::arg("num_kv_heads"), nb::arg("head_size"), nb::arg("tokens_per_block") = std::nullopt,
        nb::arg("max_num_requests"), nb::arg("max_context_length"), nb::arg("attention_window_size"),
        nb::arg("sink_token_length"), nb::arg("beam_width"), nb::arg("mask_type"), nb::arg("quant_mode"),
        nb::arg("q_scaling"), nb::arg("position_embedding_type"), nb::arg("rotary_embedding_dim"),
        nb::arg("rotary_embedding_base"), nb::arg("rotary_embedding_scale_type"), nb::arg("rotary_embedding_scales"),
        nb::arg("rotary_embedding_max_position_info"), nb::arg("use_paged_context_fmha"),
        nb::arg("attention_input_type") = std::nullopt, nb::arg("is_mla_enable"), nb::arg("q_lora_rank") = std::nullopt,
        nb::arg("kv_lora_rank") = std::nullopt, nb::arg("qk_nope_head_dim") = std::nullopt,
        nb::arg("qk_rope_head_dim") = std::nullopt, nb::arg("v_head_dim") = std::nullopt,
        nb::arg("mrope_rotary_cos_sin") = std::nullopt, nb::arg("mrope_position_deltas") = std::nullopt,
        nb::arg("mla_context_paged_kv") = std::nullopt, nb::arg("mla_context_kv_cache_block_offsets") = std::nullopt,
        nb::arg("attention_chunk_size") = std::nullopt, nb::arg("softmax_stats_tensor") = std::nullopt,
        nb::arg("spec_decoding_bool_params"), nb::arg("spec_decoding_tensor_params"), "Multi-head attention operation");
}
} // namespace tensorrt_llm::nanobind::thop
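
Note for readers new to nanobind: the snippet below is a hypothetical, self-contained toy module (not part of this patch) that shows the nb::arg(...) = std::nullopt idiom from the binding above in miniature. The <nanobind/stl/optional.h> caster is what lets a Python caller pass None, or omit the argument entirely, for a std::optional parameter.

// Toy nanobind module (illustration only; names are invented).
#include <cstdint>
#include <optional>

#include <nanobind/nanobind.h>
#include <nanobind/stl/optional.h>

namespace nb = nanobind;

// Stand-in for torch_ext::attention: one required and one optional argument.
int64_t toy_scale(int64_t x, std::optional<int64_t> scale)
{
    // Omitting `scale` (or passing None) from Python yields an empty optional.
    return x * scale.value_or(1);
}

NB_MODULE(toy_thop, m)
{
    m.def("toy_scale", &toy_scale, nb::arg("x"), nb::arg("scale") = std::nullopt, "Toy op with an optional argument");
}

From Python, toy_scale(3) and toy_scale(3, None) would then both take the value_or(1) fallback.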

cpp/tensorrt_llm/nanobind/thop/bindings.h (new file, 27 lines)
@@ -0,0 +1,27 @@
/*
 * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "tensorrt_llm/nanobind/common/customCasters.h"
#include <nanobind/nanobind.h>

namespace tensorrt_llm::nanobind::thop
{

void initBindings(nb::module_& m);

} // namespace tensorrt_llm::nanobind::thop

@@ -18,6 +18,7 @@ set(SRCS
   runtime/moeBindings.cpp
   userbuffers/bindings.cpp
   ../runtime/ipcNvlsMemory.cu
+  thop/bindings.cpp
   bindings.cpp)

 include_directories(${PROJECT_SOURCE_DIR}/include)

@@ -43,7 +44,8 @@ target_link_libraries(
   ${Python3_LIBRARIES}
   ${TORCH_LIBRARIES}
   torch_python
-  ${CUDA_NVML_LIB})
+  ${CUDA_NVML_LIB}
+  th_common)
 target_compile_definitions(
   ${TRTLLM_PYBIND_MODULE} PUBLIC TRTLLM_PYBIND_MODULE=${TRTLLM_PYBIND_MODULE}
                                  PYBIND11_DETAILED_ERROR_MESSAGES=1)

@@ -33,6 +33,7 @@
 #include "tensorrt_llm/pybind/executor/bindings.h"
 #include "tensorrt_llm/pybind/runtime/bindings.h"
 #include "tensorrt_llm/pybind/testing/modelSpecBinding.h"
+#include "tensorrt_llm/pybind/thop/bindings.h"
 #include "tensorrt_llm/pybind/userbuffers/bindings.h"
 #include "tensorrt_llm/runtime/common.h"
 #include "tensorrt_llm/runtime/cudaStream.h"

@@ -116,9 +117,11 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
     auto mInternalRuntime = mInternal.def_submodule("runtime", "Runtime internal bindings");
     auto mInternalTesting = mInternal.def_submodule("testing", "Testing internal bindings");
     auto mInternalBatchManager = mInternal.def_submodule("batch_manager", "Batch manager internal bindings");
+    auto mInternalThop = mInternal.def_submodule("thop", "Torch op internal bindings");

     tensorrt_llm::pybind::executor::initBindings(mExecutor);
     tensorrt_llm::pybind::runtime::initBindingsEarly(mInternalRuntime);
+    tensorrt_llm::pybind::thop::initBindings(mInternalThop);

     auto buildInfo = m.def_submodule("BuildInfo");
     buildInfo.attr("ENABLE_MULTI_DEVICE") = py::int_(ENABLE_MULTI_DEVICE);

cpp/tensorrt_llm/pybind/thop/bindings.cpp (new file, 59 lines)
@@ -0,0 +1,59 @@
/*
 * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "bindings.h"
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <tensorrt_llm/thop/attentionOp.h>
#include <torch/extension.h>

namespace py = pybind11;

namespace tensorrt_llm::pybind::thop
{

void initBindings(pybind11::module_& m)
{
    m.def("attention", &torch_ext::attention,
        // Parameters with default values using std::nullopt for optional arguments
        py::arg("q"), py::arg("k") = std::nullopt, py::arg("v") = std::nullopt, py::arg("output"),
        py::arg("output_sf") = std::nullopt, py::arg("out_dtype") = std::nullopt, py::arg("workspace_") = std::nullopt,
        py::arg("sequence_length"), py::arg("host_past_key_value_lengths"), py::arg("context_lengths"),
        py::arg("host_context_lengths"), py::arg("host_request_types"),
        py::arg("kv_cache_block_offsets") = std::nullopt, py::arg("host_kv_cache_block_offsets") = std::nullopt,
        py::arg("host_kv_cache_pool_pointers") = std::nullopt, py::arg("host_kv_cache_pool_mapping") = std::nullopt,
        py::arg("cache_indirection") = std::nullopt, py::arg("kv_scale_orig_quant") = std::nullopt,
        py::arg("kv_scale_quant_orig") = std::nullopt, py::arg("out_scale") = std::nullopt,
        py::arg("rotary_inv_freq") = std::nullopt, py::arg("rotary_cos_sin") = std::nullopt,
        py::arg("latent_cache") = std::nullopt, py::arg("q_pe") = std::nullopt,
        py::arg("block_ids_per_seq") = std::nullopt, py::arg("attention_sinks") = std::nullopt, py::arg("is_fused_qkv"),
        py::arg("update_kv_cache"), py::arg("predicted_tokens_per_seq"), py::arg("layer_idx"), py::arg("num_heads"),
        py::arg("num_kv_heads"), py::arg("head_size"), py::arg("tokens_per_block") = std::nullopt,
        py::arg("max_num_requests"), py::arg("max_context_length"), py::arg("attention_window_size"),
        py::arg("sink_token_length"), py::arg("beam_width"), py::arg("mask_type"), py::arg("quant_mode"),
        py::arg("q_scaling"), py::arg("position_embedding_type"), py::arg("rotary_embedding_dim"),
        py::arg("rotary_embedding_base"), py::arg("rotary_embedding_scale_type"), py::arg("rotary_embedding_scales"),
        py::arg("rotary_embedding_max_position_info"), py::arg("use_paged_context_fmha"),
        py::arg("attention_input_type") = std::nullopt, py::arg("is_mla_enable"), py::arg("q_lora_rank") = std::nullopt,
        py::arg("kv_lora_rank") = std::nullopt, py::arg("qk_nope_head_dim") = std::nullopt,
        py::arg("qk_rope_head_dim") = std::nullopt, py::arg("v_head_dim") = std::nullopt,
        py::arg("mrope_rotary_cos_sin") = std::nullopt, py::arg("mrope_position_deltas") = std::nullopt,
        py::arg("mla_context_paged_kv") = std::nullopt, py::arg("mla_context_kv_cache_block_offsets") = std::nullopt,
        py::arg("attention_chunk_size") = std::nullopt, py::arg("softmax_stats_tensor") = std::nullopt,
        py::arg("spec_decoding_bool_params"), py::arg("spec_decoding_tensor_params"), "Multi-head attention operation");
}
} // namespace tensorrt_llm::pybind::thop
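
The pybind11 file mirrors the nanobind one almost token for token. As a hypothetical stripped-down equivalent (toy names only, not part of this patch), including the kind of submodule registration that initBindings(mInternalThop) relies on in the module hunk above:

// Toy pybind11 module (illustration only; names are invented).
#include <cstdint>
#include <optional>

#include <pybind11/pybind11.h>
#include <pybind11/stl.h> // casters for std::optional, std::vector, ...

namespace py = pybind11;

int64_t toy_scale(int64_t x, std::optional<int64_t> scale)
{
    return x * scale.value_or(1);
}

PYBIND11_MODULE(toy_bindings, m)
{
    // Mirror the internal/thop submodule layout used by the real module.
    auto internal = m.def_submodule("internal", "Internal bindings");
    auto thop = internal.def_submodule("thop", "Torch op internal bindings");
    thop.def("toy_scale", &toy_scale, py::arg("x"), py::arg("scale") = std::nullopt, "Toy op with an optional argument");
}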

cpp/tensorrt_llm/pybind/thop/bindings.h (new file, 27 lines)
@@ -0,0 +1,27 @@
/*
 * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "tensorrt_llm/pybind/common/customCasters.h"
#include <pybind11/pybind11.h>

namespace tensorrt_llm::pybind::thop
{

void initBindings(pybind11::module_& m);

} // namespace tensorrt_llm::pybind::thop

@@ -21,6 +21,7 @@
 #include "tensorrt_llm/kernels/mlaKernels.h"
 #include "tensorrt_llm/runtime/torchUtils.h"
 #include "tensorrt_llm/runtime/utils/debugUtils.h"
+#include "tensorrt_llm/thop/attentionOp.h"
 #include "tensorrt_llm/thop/thUtils.h"
 #include <cstdint>
 #include <functional>

@@ -420,32 +421,31 @@ using RunnerPtr = std::shared_ptr<torch_ext::trtllm::attention::RunnerBase>;
 using torch_ext::trtllm::attention::Runner;
 using torch_ext::trtllm::attention::AttentionInputType;

-void attention_inplace(torch::Tensor q, torch::optional<torch::Tensor> k, torch::optional<torch::Tensor> v,
-    torch::Tensor& output, torch::optional<torch::Tensor> output_sf, std::optional<torch::ScalarType> out_dtype,
-    torch::optional<torch::Tensor> workspace_, torch::Tensor sequence_length, torch::Tensor host_past_key_value_lengths,
+void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<torch::Tensor> v, torch::Tensor& output,
+    std::optional<torch::Tensor> output_sf, std::optional<torch::ScalarType> out_dtype,
+    std::optional<torch::Tensor> workspace_, torch::Tensor sequence_length, torch::Tensor host_past_key_value_lengths,
     torch::Tensor context_lengths, torch::Tensor host_context_lengths, torch::Tensor host_request_types,
-    torch::optional<torch::Tensor> kv_cache_block_offsets, torch::optional<torch::Tensor> host_kv_cache_block_offsets,
-    torch::optional<torch::Tensor> host_kv_cache_pool_pointers,
-    torch::optional<torch::Tensor> host_kv_cache_pool_mapping, torch::optional<torch::Tensor> cache_indirection,
-    torch::optional<torch::Tensor> kv_scale_orig_quant, torch::optional<torch::Tensor> kv_scale_quant_orig,
-    torch::optional<torch::Tensor> out_scale, torch::optional<torch::Tensor> rotary_inv_freq,
-    torch::optional<torch::Tensor> rotary_cos_sin, torch::optional<torch::Tensor> latent_cache,
-    torch::optional<torch::Tensor> q_pe, torch::optional<torch::Tensor> block_ids_per_seq,
-    torch::optional<torch::Tensor> attention_sinks, bool const is_fused_qkv, bool const update_kv_cache,
-    int64_t const predicted_tokens_per_seq, int64_t const layer_idx, int64_t const num_heads,
-    int64_t const num_kv_heads, int64_t const head_size, std::optional<int64_t> const tokens_per_block,
-    int64_t const max_num_requests, int64_t const max_context_length, int64_t const attention_window_size,
-    int64_t const sink_token_length, int64_t const beam_width, int64_t const mask_type, int64_t const quant_mode,
-    double const q_scaling, int64_t const position_embedding_type, int64_t const rotary_embedding_dim,
-    double const rotary_embedding_base, int64_t const rotary_embedding_scale_type,
-    c10::ArrayRef<double> rotary_embedding_scales, c10::ArrayRef<int64_t> rotary_embedding_max_position_info,
+    std::optional<torch::Tensor> kv_cache_block_offsets, std::optional<torch::Tensor> host_kv_cache_block_offsets,
+    std::optional<torch::Tensor> host_kv_cache_pool_pointers, std::optional<torch::Tensor> host_kv_cache_pool_mapping,
+    std::optional<torch::Tensor> cache_indirection, std::optional<torch::Tensor> kv_scale_orig_quant,
+    std::optional<torch::Tensor> kv_scale_quant_orig, std::optional<torch::Tensor> out_scale,
+    std::optional<torch::Tensor> rotary_inv_freq, std::optional<torch::Tensor> rotary_cos_sin,
+    std::optional<torch::Tensor> latent_cache, std::optional<torch::Tensor> q_pe,
+    std::optional<torch::Tensor> block_ids_per_seq, std::optional<torch::Tensor> attention_sinks,
+    bool const is_fused_qkv, bool const update_kv_cache, int64_t const predicted_tokens_per_seq,
+    int64_t const layer_idx, int64_t const num_heads, int64_t const num_kv_heads, int64_t const head_size,
+    std::optional<int64_t> const tokens_per_block, int64_t const max_num_requests, int64_t const max_context_length,
+    int64_t const attention_window_size, int64_t const sink_token_length, int64_t const beam_width,
+    int64_t const mask_type, int64_t const quant_mode, double const q_scaling, int64_t const position_embedding_type,
+    int64_t const rotary_embedding_dim, double const rotary_embedding_base, int64_t const rotary_embedding_scale_type,
+    std::vector<double> rotary_embedding_scales, std::vector<int64_t> rotary_embedding_max_position_info,
     bool const use_paged_context_fmha, std::optional<int64_t> attention_input_type, bool is_mla_enable,
     std::optional<int64_t> q_lora_rank, std::optional<int64_t> kv_lora_rank, std::optional<int64_t> qk_nope_head_dim,
     std::optional<int64_t> qk_rope_head_dim, std::optional<int64_t> v_head_dim,
-    torch::optional<torch::Tensor> mrope_rotary_cos_sin, torch::optional<torch::Tensor> mrope_position_deltas,
+    std::optional<torch::Tensor> mrope_rotary_cos_sin, std::optional<torch::Tensor> mrope_position_deltas,
     std::optional<torch::Tensor> mla_context_paged_kv, std::optional<torch::Tensor> mla_context_kv_cache_block_offsets,
     std::optional<int64_t> attention_chunk_size, std::optional<torch::Tensor> softmax_stats_tensor,
-    c10::List<bool> spec_decoding_bool_params, c10::ArrayRef<std::optional<torch::Tensor>> spec_decoding_tensor_params)
+    std::vector<bool> spec_decoding_bool_params, std::vector<std::optional<torch::Tensor>> spec_decoding_tensor_params)
 {
     TLLM_LOG_TRACE("Attention op starts at layer %d", layer_idx);
     // Use these tensors to infer if the attention is using KV cache
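
A key part of the signature change is that the exported parameters now use std::optional and std::vector rather than torch::optional, c10::ArrayRef, and c10::List; std::optional and std::vector are the types that the STL caster headers included in the new binding files (<nanobind/stl/optional.h>, <nanobind/stl/vector.h>, <pybind11/stl.h>) convert to and from Python objects. Below is a hypothetical toy binding (plain integers instead of tensors, not part of this patch) of a vector-of-optionals parameter, the same shape of type as the new spec_decoding_tensor_params:

// Toy vector-of-optionals binding (illustration only; names are invented).
#include <cstdint>
#include <optional>
#include <vector>

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

namespace py = pybind11;

// Counts the entries of a Python list such as [None, 3, None] that are not None.
int64_t count_present(std::vector<std::optional<int64_t>> const& values)
{
    int64_t count = 0;
    for (auto const& value : values)
    {
        if (value.has_value())
        {
            ++count;
        }
    }
    return count;
}

PYBIND11_MODULE(toy_casters, m)
{
    m.def("count_present", &count_present, py::arg("values"));
}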

@@ -751,78 +751,5 @@ bool attention_supports_nvfp4_output(int64_t const num_heads, int64_t const num_

 TORCH_LIBRARY_FRAGMENT(trtllm, m)
 {
-    m.def(
-        "attention_inplace("
-        "Tensor q"
-        ", Tensor? k"
-        ", Tensor? v"
-        ", Tensor(a!) output"
-        ", Tensor(b!)? output_sf"
-        ", ScalarType? out_dtype"
-        ", Tensor? workspace"
-        ", Tensor sequence_length"
-        ", Tensor host_past_key_value_lengths"
-        ", Tensor context_lengths"
-        ", Tensor host_context_lengths"
-        ", Tensor host_request_types"
-        ", Tensor? kv_cache_block_offsets"
-        ", Tensor? host_kv_cache_block_offsets"
-        ", Tensor? host_kv_cache_pool_pointers"
-        ", Tensor? host_kv_cache_pool_mapping"
-        ", Tensor? cache_indirection"
-        ", Tensor? kv_scale_orig_quant"
-        ", Tensor? kv_scale_quant_orig"
-        ", Tensor? out_scale"
-        ", Tensor? rotary_inv_freq"
-        ", Tensor? rotary_cos_sin"
-        ", Tensor? latent_cache"
-        ", Tensor? q_pe"
-        ", Tensor? block_ids_per_seq"
-        ", Tensor? attention_sinks"
-        ", bool is_fused_qkv"
-        ", bool update_kv_cache"
-        ", int predicted_tokens_per_seq"
-        ", int layer_idx"
-        ", int num_heads"
-        ", int num_kv_heads"
-        ", int head_size"
-        ", SymInt? tokens_per_block"
-        ", SymInt max_num_requests"
-        ", SymInt max_context_length"
-        ", SymInt attention_window_size"
-        ", int sink_token_length"
-        ", int beam_width"
-        ", int mask_type"
-        ", int quant_mode"
-        ", float q_scaling"
-        ", int position_embedding_type"
-        ", int rotary_embedding_dim"
-        ", float rotary_embedding_base"
-        ", int rotary_embedding_scale_type"
-        ", float[] rotary_embedding_scales"
-        ", int[] rotary_embedding_max_position_info"
-        ", bool use_paged_context_fmha"
-        ", int? attention_input_type"
-        ", bool is_mla_enable"
-        ", int? q_lora_rank"
-        ", int? kv_lora_rank"
-        ", int? qk_nope_head_dim"
-        ", int? qk_rope_head_dim"
-        ", int? v_head_dim"
-        ", Tensor? mrope_rotary_cos_sin"
-        ", Tensor? mrope_position_deltas"
-        ", Tensor? mla_context_paged_kv"
-        ", Tensor? mla_context_kv_cache_block_offsets"
-        ", int? attention_chunk_size"
-        ", Tensor? softmax_stats_tensor"
-        ", bool[] spec_decoding_bool_params"
-        ", Tensor?[] spec_decoding_tensor_params"
-        ") -> ()");
-
     m.def("attention_supports_nvfp4_output", &torch_ext::attention_supports_nvfp4_output);
 }
-
-TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
-{
-    m.impl("attention_inplace", &torch_ext::attention_inplace);
-}
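
For contrast with the direct pybind/nanobind route added above, the registration removed here used a TorchScript-style schema string plus a dispatcher implementation. A minimal hypothetical sketch of that pattern for an in-place toy op (namespace and op invented for illustration, not part of TensorRT-LLM):

// Toy schema-string registration (illustration only; names are invented).
#include <torch/library.h>
#include <torch/torch.h>

namespace
{
// In-place toy op: scales `output` by `factor`.
void toy_scale_inplace(torch::Tensor& output, double factor)
{
    output.mul_(factor);
}
} // namespace

TORCH_LIBRARY_FRAGMENT(toy_ns, m)
{
    // "(a!)" marks `output` as mutated in place, as in the removed schema above.
    m.def("toy_scale_inplace(Tensor(a!) output, float factor) -> ()");
}

TORCH_LIBRARY_IMPL(toy_ns, CPU, m)
{
    m.impl("toy_scale_inplace", &toy_scale_inplace);
}

Once compiled into an extension, such an op would be reachable from Python as torch.ops.toy_ns.toy_scale_inplace(t, 2.0), which is the calling style the Python-side hunk below is moving away from.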

cpp/tensorrt_llm/thop/attentionOp.h (new file, 63 lines)
@@ -0,0 +1,63 @@
/*
 * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <optional>
#include <torch/extension.h>

namespace torch_ext
{

/**
 * @brief Attention operation for TensorRT-LLM
 *
 * This function performs multi-head attention computation in-place, supporting both
 * context and generation phases with various optimization features including:
 * - Fused QKV processing
 * - KV cache management
 * - Multiple position embedding types (RoPE, ALiBi, etc.)
 * - Quantization support (FP8, FP4, etc.)
 * - Multi-layer attention (MLA)
 * - Speculative decoding
 */
void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<torch::Tensor> v, torch::Tensor& output,
    std::optional<torch::Tensor> output_sf, std::optional<torch::ScalarType> out_dtype,
    std::optional<torch::Tensor> workspace_, torch::Tensor sequence_length, torch::Tensor host_past_key_value_lengths,
    torch::Tensor context_lengths, torch::Tensor host_context_lengths, torch::Tensor host_request_types,
    std::optional<torch::Tensor> kv_cache_block_offsets, std::optional<torch::Tensor> host_kv_cache_block_offsets,
    std::optional<torch::Tensor> host_kv_cache_pool_pointers, std::optional<torch::Tensor> host_kv_cache_pool_mapping,
    std::optional<torch::Tensor> cache_indirection, std::optional<torch::Tensor> kv_scale_orig_quant,
    std::optional<torch::Tensor> kv_scale_quant_orig, std::optional<torch::Tensor> out_scale,
    std::optional<torch::Tensor> rotary_inv_freq, std::optional<torch::Tensor> rotary_cos_sin,
    std::optional<torch::Tensor> latent_cache, std::optional<torch::Tensor> q_pe,
    std::optional<torch::Tensor> block_ids_per_seq, std::optional<torch::Tensor> attention_sinks,
    bool const is_fused_qkv, bool const update_kv_cache, int64_t const predicted_tokens_per_seq,
    int64_t const layer_idx, int64_t const num_heads, int64_t const num_kv_heads, int64_t const head_size,
    std::optional<int64_t> const tokens_per_block, int64_t const max_num_requests, int64_t const max_context_length,
    int64_t const attention_window_size, int64_t const sink_token_length, int64_t const beam_width,
    int64_t const mask_type, int64_t const quant_mode, double const q_scaling, int64_t const position_embedding_type,
    int64_t const rotary_embedding_dim, double const rotary_embedding_base, int64_t const rotary_embedding_scale_type,
    std::vector<double> rotary_embedding_scales, std::vector<int64_t> rotary_embedding_max_position_info,
    bool const use_paged_context_fmha, std::optional<int64_t> attention_input_type, bool is_mla_enable,
    std::optional<int64_t> q_lora_rank, std::optional<int64_t> kv_lora_rank, std::optional<int64_t> qk_nope_head_dim,
    std::optional<int64_t> qk_rope_head_dim, std::optional<int64_t> v_head_dim,
    torch::optional<torch::Tensor> mrope_rotary_cos_sin, torch::optional<torch::Tensor> mrope_position_deltas,
    std::optional<torch::Tensor> mla_context_paged_kv, std::optional<torch::Tensor> mla_context_kv_cache_block_offsets,
    std::optional<int64_t> attention_chunk_size, std::optional<torch::Tensor> softmax_stats_tensor,
    std::vector<bool> spec_decoding_bool_params, std::vector<std::optional<torch::Tensor>> spec_decoding_tensor_params);

} // namespace torch_ext

@@ -7,6 +7,7 @@ from typing import Optional, Tuple, Union
 import torch

 from tensorrt_llm._utils import get_sm_version
+from tensorrt_llm.bindings.internal import thop
 from tensorrt_llm.functional import AttentionMaskType
 from tensorrt_llm.logger import logger
 from tensorrt_llm.models.modeling_utils import QuantConfig

@@ -419,7 +420,7 @@ class TrtllmAttentionWrapper:
             self.spec_decoding_position_offsets, self.spec_decoding_packed_mask
         ]

-        torch.ops.trtllm.attention_inplace(
+        thop.attention(
             q,
             k,
             v,