mirror of https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-02-05 18:51:38 +08:00

[None][chore] Removing cpp/tensorrt_llm/pybind (#11026)

Signed-off-by: Linda-Stadter <57756729+Linda-Stadter@users.noreply.github.com>

This commit is contained in:
parent 38bcee189c
commit 29647d9446
@@ -1,77 +0,0 @@
set(TRTLLM_PYBIND_MODULE bindings)
set(TRTLLM_PYBIND_MODULE
    ${TRTLLM_PYBIND_MODULE}
    PARENT_SCOPE)

set(SRCS
    batch_manager/algorithms.cpp
    batch_manager/bindings.cpp
    batch_manager/buffers.cpp
    batch_manager/cacheTransceiver.cpp
    batch_manager/kvCacheConnector.cpp
    batch_manager/kvCacheManager.cpp
    batch_manager/kvCacheManagerV2Utils.cpp
    batch_manager/llmRequest.cpp
    executor/bindings.cpp
    executor/executor.cpp
    executor/executorConfig.cpp
    executor/request.cpp
    process_group/bindings.cpp
    runtime/bindings.cpp
    runtime/hostfunc.cpp
    common/tllmExceptions.cpp
    testing/modelSpecBinding.cpp
    runtime/moeBindings.cpp
    userbuffers/bindings.cpp
    ../runtime/ipcNvlsMemory.cu
    thop/bindings.cpp
    bindings.cpp)

include_directories(${PROJECT_SOURCE_DIR}/include)

pybind11_add_module(${TRTLLM_PYBIND_MODULE} ${SRCS})

set_property(TARGET ${TRTLLM_PYBIND_MODULE} PROPERTY POSITION_INDEPENDENT_CODE
                                                     ON)

target_link_directories(${TRTLLM_PYBIND_MODULE} PUBLIC
                        "${TORCH_INSTALL_PREFIX}/lib")

if(ENABLE_NVSHMEM)
  target_link_libraries(${TRTLLM_PYBIND_MODULE} PUBLIC nvshmem::nvshmem_host
                                                       nvshmem::nvshmem_device)
endif()

target_link_libraries(
  ${TRTLLM_PYBIND_MODULE}
  PUBLIC ${SHARED_TARGET}
         ${Python3_LIBRARIES}
         ${TORCH_LIBRARIES}
         torch_python
         ${CUDA_DRV_LIB}
         ${CUDA_NVML_LIB}
         th_common
         pg_utils)
target_compile_definitions(
  ${TRTLLM_PYBIND_MODULE} PUBLIC TRTLLM_PYBIND_MODULE=${TRTLLM_PYBIND_MODULE}
                                 PYBIND11_DETAILED_ERROR_MESSAGES=1)

if(NOT WIN32)
  set(TRTLLM_PYBIND_MODULE_RPATH_LIST
      "$ORIGIN/libs" # TRTLLM libraries
      "$ORIGIN/../../.." # Shared libraries under $PREFIX/lib
      "$ORIGIN/../nvidia/nccl/lib" # NCCL libraries
  )
  set_target_properties(
    ${TRTLLM_PYBIND_MODULE} PROPERTIES BUILD_RPATH
                                       "${TRTLLM_PYBIND_MODULE_RPATH_LIST}")
  set_target_properties(
    ${TRTLLM_PYBIND_MODULE} PROPERTIES LINK_FLAGS
                                       "${AS_NEEDED_FLAG} ${UNDEFINED_FLAG}")
endif()

# Build transfer_agent_binding when building bindings (if NIXL is enabled)
if(TARGET ${TRANSFER_AGENT_BINDING_TARGET})
  add_dependencies(${TRTLLM_PYBIND_MODULE} ${TRANSFER_AGENT_BINDING_TARGET})
endif()
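Editor's note: the hunk above removes the module's CMake build rules. `pybind11_add_module` compiles the listed sources into one extension whose name comes from `TRTLLM_PYBIND_MODULE`, and `target_compile_definitions` forwards the same value to the C++ side. A minimal sketch of how such a compile definition is typically consumed in the module entry point; the `ping` function and docstring are hypothetical placeholders, not part of the removed code:

    #include <pybind11/pybind11.h>

    // TRTLLM_PYBIND_MODULE arrives on the compile line, e.g. -DTRTLLM_PYBIND_MODULE=bindings,
    // exactly as target_compile_definitions sets it above. PYBIND11_MODULE expands the macro,
    // so the generated PyInit_* entry point matches the shared-object name CMake produces.
    PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
    {
        m.doc() = "demo module";             // hypothetical docstring
        m.def("ping", [] { return true; });  // hypothetical function, just to make the module testable
    }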
@@ -1,127 +0,0 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "algorithms.h"
#include "tensorrt_llm/batch_manager/allocateKvCache.h"
#include "tensorrt_llm/batch_manager/assignReqSeqSlots.h"
#include "tensorrt_llm/batch_manager/capacityScheduler.h"
#include "tensorrt_llm/batch_manager/createNewDecoderRequests.h"
#include "tensorrt_llm/batch_manager/decoderBuffers.h"
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
#include "tensorrt_llm/batch_manager/llmRequest.h"
#include "tensorrt_llm/batch_manager/logitsPostProcessor.h"
#include "tensorrt_llm/batch_manager/medusaBuffers.h"
#include "tensorrt_llm/batch_manager/microBatchScheduler.h"
#include "tensorrt_llm/batch_manager/pauseRequests.h"
#include "tensorrt_llm/batch_manager/peftCacheManager.h"
#include "tensorrt_llm/runtime/decoderState.h"
#include "tensorrt_llm/runtime/torch.h"
#include "tensorrt_llm/runtime/torchView.h"

#include <ATen/core/TensorBody.h>
#include <pybind11/cast.h>
#include <pybind11/functional.h>
#include <pybind11/operators.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <torch/extension.h>

#include <optional>

namespace py = pybind11;

namespace tr = tensorrt_llm::runtime;
using namespace tensorrt_llm::batch_manager;

void tensorrt_llm::pybind::batch_manager::algorithms::initBindings(pybind11::module_& m)
{
    py::class_<CapacityScheduler>(m, CapacityScheduler::name)
        .def(py::init<SizeType32, executor::CapacitySchedulerPolicy, bool, bool, LlmRequestState, LlmRequestState>(),
            py::arg("max_num_requests"), py::arg("capacity_scheduler_policy"), py::arg("has_kv_cache_manager"),
            py::arg("two_step_lookahead") = false,
            py::arg_v("no_schedule_until_state", LlmRequestState::kCONTEXT_INIT, "LlmRequestState.CONTEXT_INIT"),
            py::arg_v("no_schedule_after_state", LlmRequestState::kGENERATION_COMPLETE,
                "LlmRequestState.GENERATION_COMPLETE"))
        .def("__call__", &CapacityScheduler::operator(), py::arg("active_requests"),
            py::arg("kv_cache_manager") = nullptr, py::arg("peft_cache_manager") = nullptr,
            py::arg("cross_kv_cache_manager") = nullptr)
        .def("name", [](CapacityScheduler const&) { return CapacityScheduler::name; });

    py::class_<MicroBatchScheduler>(m, MicroBatchScheduler::name)
        .def(py::init<std::optional<batch_scheduler::ContextChunkingConfig>, std::optional<SizeType32>, LlmRequestState,
                 LlmRequestState>(),
            py::arg("ctx_chunk_config") = std::nullopt, py::arg("max_context_length") = std::nullopt,
            py::arg_v("no_schedule_until_state", LlmRequestState::kCONTEXT_INIT, "LlmRequestState.CONTEXT_INIT"),
            py::arg_v("no_schedule_after_state", LlmRequestState::kGENERATION_TO_COMPLETE,
                "LlmRequestState.GENERATION_TO_COMPLETE"))
        .def("__call__", &MicroBatchScheduler::operator(), py::arg("active_requests"), py::arg("inflight_req_ids"),
            py::arg("max_batch_size_runtime"), py::arg("max_num_tokens_runtime"))
        .def("name", [](MicroBatchScheduler const&) { return MicroBatchScheduler::name; });

    py::class_<PauseRequests>(m, PauseRequests::name)
        .def(py::init<SizeType32>(), py::arg("max_input_len"))
        .def("__call__", &PauseRequests::operator(), py::arg("requests_to_pause"), py::arg("inflight_req_ids"),
            py::arg("req_ids_to_pause"), py::arg("pause_flagged"), py::arg("seq_slot_manager"),
            py::arg("kv_cache_manager") = std::nullopt, py::arg("cross_kv_cache_manager") = std::nullopt,
            py::arg("peft_cache_manager") = std::nullopt)
        .def("name", [](PauseRequests const&) { return PauseRequests::name; });

    py::class_<AssignReqSeqSlots>(m, AssignReqSeqSlots::name)
        .def(py::init())
        .def("__call__", &AssignReqSeqSlots::operator(), py::arg("seq_slot_manager"), py::arg("context_requests"),
            py::arg("generation_requests"))
        .def("name", [](AssignReqSeqSlots const&) { return AssignReqSeqSlots::name; });

    py::class_<AllocateKvCache>(m, AllocateKvCache::name)
        .def(py::init(), py::call_guard<py::gil_scoped_release>())
        .def("__call__", &AllocateKvCache::operator(), py::arg("kv_cache_manager"), py::arg("context_requests"),
            py::arg("generation_requests"), py::arg("model_config"), py::arg("cross_kv_cache_manager") = std::nullopt,
            py::call_guard<py::gil_scoped_release>())
        .def("name", [](AllocateKvCache const&) { return AllocateKvCache::name; });

    py::class_<LogitsPostProcessor>(m, LogitsPostProcessor::name)
        .def(py::init())
        .def("__call__", &LogitsPostProcessor::operator(), py::arg("decoder_input_buffers"),
            py::arg("replicate_logits_post_processor"), py::arg("world_config"), py::arg("stream"),
            py::arg("logits_post_processor_batched") = std::nullopt)
        .def("name", [](LogitsPostProcessor const&) { return LogitsPostProcessor::name; });

    py::class_<CreateNewDecoderRequests>(m, CreateNewDecoderRequests::name)
        .def(py::init<bool, bool, bool>(), py::arg("speculative_decoding_fast_logits"),
            py::arg("is_leader_in_orch_mode"), py::arg("is_normalize_log_probs"))
        .def(
            "__call__",
            [](CreateNewDecoderRequests& self, tr::ModelConfig const& modelConfig, tr::WorldConfig const& worldConfig,
                executor::DecodingConfig const& decodingConfig, RequestVector const& contextRequests,
                nvinfer1::DataType logitsType, DecoderInputBuffers& inputBuffers,
                runtime::decoder::DecoderState& decoderState, tensorrt_llm::runtime::CudaStream const& runtimeStream,
                tensorrt_llm::runtime::CudaStream const& decoderStream, SizeType32 maxSequenceLength,
                SizeType32 beamWidth)
            {
                OptionalRef<MedusaBuffers const> medusaBuffers = std::nullopt;
                auto [batchSlots, samplingConfigs, lookaheadPrompt, lookaheadAlgoConfigs]
                    = self(modelConfig, worldConfig, decodingConfig, contextRequests, logitsType, inputBuffers,
                        decoderState, runtimeStream, decoderStream, maxSequenceLength, beamWidth, medusaBuffers);

                return std::tuple{runtime::Torch::tensor(batchSlots), std::move(samplingConfigs),
                    std::move(lookaheadPrompt), std::move(lookaheadAlgoConfigs)};
            },
            py::arg("model_config"), py::arg("world_config"), py::arg("decoding_config"), py::arg("context_requests"),
            py::arg("logits_type"), py::arg("decoder_input_buffers"), py::arg("decoder_state"),
            py::arg("runtime_stream"), py::arg("decoder_stream"), py::arg("max_sequence_length"), py::arg("beam_width"))
        .def("name", [](CreateNewDecoderRequests const&) { return CreateNewDecoderRequests::name; });
}
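Editor's note: every algorithm in the hunk above follows the same binding recipe: a functor class exposed with `py::class_`, its `operator()` bound as Python `__call__`, and enum-valued defaults declared with `py::arg_v` so the generated docstring shows a readable default instead of a raw enum repr. A self-contained sketch of that recipe, with hypothetical `Scheduler`/`State`/`scheduler_demo` names:

    #include <pybind11/pybind11.h>

    namespace py = pybind11;

    enum class State { kInit, kDone };

    struct Scheduler
    {
        static constexpr char const* name = "Scheduler";

        // The functor's operator() becomes Python's __call__.
        int operator()(int numRequests, State until) const
        {
            return until == State::kDone ? 0 : numRequests;
        }
    };

    PYBIND11_MODULE(scheduler_demo, m)  // hypothetical demo module
    {
        py::enum_<State>(m, "State").value("INIT", State::kInit).value("DONE", State::kDone);

        py::class_<Scheduler>(m, Scheduler::name)
            .def(py::init())
            // py::arg_v(name, default, text): the third argument controls how the default
            // renders in help(), mirroring the LlmRequestState defaults above.
            .def("__call__", &Scheduler::operator(), py::arg("num_requests"),
                py::arg_v("until", State::kInit, "State.INIT"))
            .def("name", [](Scheduler const&) { return Scheduler::name; });
    }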
@@ -1,28 +0,0 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "tensorrt_llm/pybind/common/customCasters.h"
#include <pybind11/pybind11.h>

namespace tensorrt_llm::pybind::batch_manager::algorithms
{

void initBindings(pybind11::module_& m);

}
@@ -1,517 +0,0 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "bindings.h"

#include "tensorrt_llm/batch_manager/common.h"
#include "tensorrt_llm/batch_manager/decoderBuffers.h"
#include "tensorrt_llm/batch_manager/kvCacheConnector.h"
#include "tensorrt_llm/batch_manager/medusaBuffers.h"
#include "tensorrt_llm/batch_manager/microBatchScheduler.h"
#include "tensorrt_llm/batch_manager/peftCacheManager.h"
#include "tensorrt_llm/batch_manager/rnnStateManager.h"
#include "tensorrt_llm/batch_manager/sequenceSlotManager.h"
#include "tensorrt_llm/pybind/common/bindTypes.h"
#include "tensorrt_llm/runtime/gptDecoderBatched.h"
#include "tensorrt_llm/runtime/runtimeKernels.h"
#include "tensorrt_llm/runtime/torch.h"
#include "tensorrt_llm/runtime/torchView.h"

#include <ATen/ATen.h>
#include <pybind11/chrono.h>
#include <pybind11/functional.h>
#include <pybind11/operators.h>
#include <pybind11/stl.h>
#include <pybind11/stl_bind.h>
#include <torch/extension.h>
#include <tuple>

namespace py = pybind11;
namespace tb = tensorrt_llm::batch_manager;
namespace tle = tensorrt_llm::executor;
namespace tr = tensorrt_llm::runtime;

using namespace tensorrt_llm::runtime;

namespace tensorrt_llm::pybind::batch_manager
{

void initBindings(pybind11::module_& m)
{
    using GenLlmReq = tb::GenericLlmRequest<runtime::ITensor::SharedPtr>;

    // Create and register exceptions in module scope
    static PyObject* peft_exc = PyErr_NewException(
        "tensorrt_llm.bindings.internal.batch_manager.PeftTaskNotCachedException", nullptr, nullptr);
    static PyObject* lora_exc
        = PyErr_NewException("tensorrt_llm.bindings.internal.batch_manager.LoraCacheFullException", nullptr, nullptr);

    m.add_object("PeftTaskNotCachedException", py::handle(peft_exc));
    m.add_object("LoraCacheFullException", py::handle(lora_exc));

    // Register with no captures
    py::register_exception_translator(
        [](std::exception_ptr p)
        {
            try
            {
                if (p)
                    std::rethrow_exception(p);
            }
            catch (const tb::PeftTaskNotCachedException& e)
            {
                PyErr_SetString(peft_exc, e.what());
            }
            catch (const tr::LoraCacheFullException& e)
            {
                PyErr_SetString(lora_exc, e.what());
            }
        });

    PybindUtils::bindSet<tb::ReqIdsSet>(m, "ReqIdsSet");

    py::enum_<tb::LlmRequestType>(m, "LlmRequestType")
        .value("LLMREQUEST_TYPE_CONTEXT_AND_GENERATION", tb::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION)
        .value("LLMREQUEST_TYPE_CONTEXT_ONLY", tb::LLMREQUEST_TYPE_CONTEXT_ONLY)
        .value("LLMREQUEST_TYPE_GENERATION_ONLY", tb::LLMREQUEST_TYPE_GENERATION_ONLY)
        .export_values();

    py::class_<tb::batch_scheduler::ContextChunkingConfig>(m, "ContextChunkingConfig")
        .def(py::init<tle::ContextChunkingPolicy, tensorrt_llm::runtime::SizeType32>(), py::arg("chunking_policy"),
            py::arg("chunk_unit_size"))
        .def_readwrite("chunking_policy", &tb::batch_scheduler::ContextChunkingConfig::chunkingPolicy)
        .def_readwrite("chunk_unit_size", &tb::batch_scheduler::ContextChunkingConfig::chunkUnitSize);

    py::classh<GenLlmReq>(m, "GenericLlmRequest")
        .def("set_exclude_input_from_output", &GenLlmReq::setExcludeInputFromOutput, py::arg("exclude"))
        .def("get_num_tokens", &GenLlmReq::getNumTokens, py::arg("beam"))
        .def_property_readonly("max_beam_num_tokens", &GenLlmReq::getMaxBeamNumTokens)
        .def("get_token", &GenLlmReq::getToken, py::arg("beam"), py::arg("pos"))
        .def("get_tokens", py::overload_cast<GenLlmReq::SizeType32>(&GenLlmReq::getTokens, py::const_), py::arg("beam"))
        .def("get_tokens", py::overload_cast<>(&GenLlmReq::getTokens, py::const_))
        .def("get_last_tokens", py::overload_cast<GenLlmReq::SizeType32>(&GenLlmReq::getLastTokens), py::arg("beam"))
        .def("get_last_tokens", py::overload_cast<>(&GenLlmReq::getLastTokens))
        .def("get_beam_width_by_iter", &GenLlmReq::getBeamWidthByIter, py::arg("for_next_iteration") = false)
        .def_property_readonly("max_num_generated_tokens", &GenLlmReq::getMaxNumGeneratedTokens)
        .def("will_complete_next_iteration", &GenLlmReq::willCompleteNextIteration)
        .def("add_new_token", &GenLlmReq::addNewToken, py::arg("token"), py::arg("beam"))
        .def("add_new_tokens", &GenLlmReq::addNewTokens, py::arg("beam_tokens"))
        .def_property_readonly("num_draft_tokens", &GenLlmReq::getNumDraftTokens)
        .def("set_generated_tokens", &GenLlmReq::setGeneratedTokens, py::arg("generated_beam_tokens"))
        .def("pause", &GenLlmReq::pause, py::arg("max_input_len"))
        .def_property("max_sent_token_len", &GenLlmReq::getMaxSentTokenLen, &GenLlmReq::setMaxSentTokenLen)
        .def_property_readonly("prompt_embedding_table", &GenLlmReq::getPromptEmbeddingTable)
        .def_property_readonly("multimodal_embedding", &GenLlmReq::getMultimodalEmbedding)
        .def_property_readonly("mrope_rotary_cos_sin", &GenLlmReq::getMropeRotaryCosSin)
        .def_property_readonly("bad_words_list", &GenLlmReq::getBadWordsList)
        .def_property("draft_logits", &GenLlmReq::getDraftLogits, &GenLlmReq::setDraftLogits)
        .def_property_readonly("embedding_bias", &GenLlmReq::getEmbeddingBias)
        .def_property("lora_config", &GenLlmReq::getLoraConfig, &GenLlmReq::setLoraConfig)
        .def_property("lora_weights", &GenLlmReq::getLoraWeights, &GenLlmReq::setLoraWeights)
        .def_property_readonly("stop_words_list", &GenLlmReq::getStopWordsList)
        .def_property_readonly("context_logits", &GenLlmReq::getContextLogitsHost)
        .def_property_readonly("generation_logits", &GenLlmReq::getGenerationLogitsHost)
        .def_property_readonly("prompt_vocab_size", &GenLlmReq::getPromptVocabSize)
        .def_property_readonly("mrope_position_deltas", &GenLlmReq::getMropePositionDeltas)
        .def_property_readonly("lora_task_id", &GenLlmReq::getLoraTaskId)
        .def_property_readonly("lookahead_config", &GenLlmReq::getLookaheadConfig)
        .def_property("context_chunk_size", &GenLlmReq::getContextChunkSize, &GenLlmReq::setContextChunkSize)
        .def_property("decoding_iter", &GenLlmReq::getDecodingIter, &GenLlmReq::setDecodingIter)
        .def_readwrite("request_id", &GenLlmReq::mRequestId)
        .def_readwrite("prompt_len", &GenLlmReq::mPromptLen)
        .def_readwrite("max_new_tokens", &GenLlmReq::mMaxNewTokens)
        .def_readwrite("sampling_config", &GenLlmReq::mSamplingConfig)
        .def_property("state", &GenLlmReq::getState, &GenLlmReq::setState)
        .def_property_readonly("state_value", [](GenLlmReq const& self) { return static_cast<int>(self.getState()); })
        .def_property("streaming", &GenLlmReq::isStreaming, &GenLlmReq::setStreaming)
        .def_readwrite("end_id", &GenLlmReq::mEndId)
        .def_readwrite("pad_id", &GenLlmReq::mPadId)
        .def_readwrite("seq_slot", &GenLlmReq::mSeqSlot)
        .def_property_readonly("return_log_probs", &GenLlmReq::returnLogProbs)
        .def_property_readonly("return_context_logits", &GenLlmReq::getReturnContextLogits)
        .def_property_readonly("return_generation_logits", &GenLlmReq::getReturnGenerationLogits)
        .def_property_readonly("log_probs", py::overload_cast<>(&GenLlmReq::getLogProbs, py::const_))
        .def("get_log_probs", py::overload_cast<GenLlmReq::SizeType32>(&GenLlmReq::getLogProbs, py::const_))
        .def("set_log_probs", &GenLlmReq::setLogProbs, py::arg("log_probs"), py::arg("beam"))
        .def("set_return_encoder_output", &GenLlmReq::setReturnEncoderOutput, py::arg("return_encoder_output"))
        .def("get_return_encoder_output", &GenLlmReq::getReturnEncoderOutput)
        .def("priority", py::overload_cast<>(&GenLlmReq::priority, py::const_))
        .def("set_priority", py::overload_cast<tle::PriorityType>(&GenLlmReq::setPriority))
        .def_property_readonly("cum_log_probs", &GenLlmReq::getCumLogProbs)
        .def("set_cum_log_prob", &GenLlmReq::setCumLogProb, py::arg("cum_log_prob"), py::arg("beam"))
        .def("update_num_tokens_per_iteration", &GenLlmReq::updateNumTokensPerIteration,
            py::arg("num_tokens_per_iteration"), py::arg("model_config"))
        .def_property_readonly("orig_prompt_len", &GenLlmReq::getOrigPromptLen)
        .def("has_draft_tokens", &GenLlmReq::hasDraftTokens)
        .def("move_to_next_context_chunk", &GenLlmReq::moveToNextContextChunk)
        .def_property_readonly("is_last_context_chunk", &GenLlmReq::isLastContextChunk)
        .def_property_readonly("is_first_context_chunk", &GenLlmReq::isFirstContextChunk)
        .def_property_readonly("context_remaining_length", &GenLlmReq::getContextRemainingLength)
        .def_property_readonly("context_logits", &GenLlmReq::getContextLogitsHost)
        .def_property_readonly("num_draft_tokens", &GenLlmReq::getNumDraftTokens)
        .def("set_finished_reason", &GenLlmReq::setFinishedReason, py::arg("finish_reason"), py::arg("beam"))
        .def_property_readonly("is_finished", &GenLlmReq::isFinished)
        .def_property_readonly("is_finished_due_to_length", &GenLlmReq::isFinishedDueToLength)
        .def_property_readonly("is_finished_due_to_cancellation", &GenLlmReq::isFinishedDueToCancellation)
        .def_property(
            "context_current_position", &GenLlmReq::getContextCurrentPosition, &GenLlmReq::setContextCurrentPosition)
        .def_property_readonly("prepopulated_prompt_len", &GenLlmReq::getPrepopulatedPromptLen)
        .def_property(
            "guided_decoding_params", &GenLlmReq::getGuidedDecodingParams, &GenLlmReq::setGuidedDecodingParams)
        .def_property("context_phase_params", &GenLlmReq::getContextPhaseParams, &GenLlmReq::setContextPhaseParams)
        .def_property_readonly("is_context_only_request", &GenLlmReq::isContextOnlyRequest)
        .def_property_readonly("is_generation_only_request", &GenLlmReq::isGenerationOnlyRequest)
        .def_property_readonly("is_generation_to_complete_state", &GenLlmReq::isGenerationToCompleteState)
        .def_property_readonly("is_generation_complete_state", &GenLlmReq::isGenerationCompleteState)
        .def_property_readonly("is_context_finished", &GenLlmReq::isContextFinished)
        .def_property_readonly("is_disagg_generation_init_state", &GenLlmReq::isDisaggGenerationInitState)
        .def_property_readonly(
            "is_disagg_generation_transmission_complete", &GenLlmReq::isDisaggGenerationTransmissionComplete)
        .def_property_readonly(
            "is_disagg_generation_transmission_in_progress", &GenLlmReq::isDisaggGenerationTransmissionInProgress)
        .def_property_readonly("is_encoder_init_state", &GenLlmReq::isEncoderInitState)
        .def_property_readonly("is_context_init_state", &GenLlmReq::isContextInitState)
        .def_property_readonly("is_generation_in_progress_state", &GenLlmReq::isGenerationInProgressState)
        .def_property_readonly("is_disagg_context_transmission_state", &GenLlmReq::isDisaggContextTransmissionState)
        .def_property_readonly("is_disagg_context_complete_state", &GenLlmReq::isDisaggContextCompleteState)
        .def_property_readonly("stage", &GenLlmReq::getRequestStage)
        .def_property_readonly("kv_cache_transfer_time_ms", &GenLlmReq::getKvCacheTransferTimeMS)
        .def_property_readonly("kv_cache_size", &GenLlmReq::getKvCacheSize)
        .def_property_readonly("avg_decoded_tokens_per_iter", &GenLlmReq::getAvgDecodedTokensPerIter)
        .def_property_readonly("alloc_total_blocks", &GenLlmReq::getAllocTotalBlocksPerRequest)
        .def_property_readonly("alloc_new_blocks", &GenLlmReq::getAllocNewBlocksPerRequest)
        .def("alloc_context_logits", &GenLlmReq::allocContextLogitsHost, py::arg("vocab_size"), py::arg("logit_dtype"))
        .def_property_readonly("reused_blocks", &GenLlmReq::getReusedBlocksPerRequest)
        .def_property_readonly("missed_blocks", &GenLlmReq::getMissedBlocksPerRequest)
        .def_property_readonly("kv_cache_hit_rate", &GenLlmReq::getKVCacheHitRatePerRequest)
        .def_property_readonly("llm_request_type", &GenLlmReq::getLlmRequestType)
        .def_property_readonly("parent_request_id", &GenLlmReq::getParentRequestId)
        .def_property_readonly("is_child", &GenLlmReq::isChild)
        .def_property_readonly("cache_salt_id", &GenLlmReq::getCacheSaltID)
        .def_property_readonly("multimodal_hashes",
            [](GenLlmReq& self)
            {
                std::optional<std::vector<std::vector<GenLlmReq::SizeType32>>> hashes = std::nullopt;
                if (self.getMultimodalHashes())
                {
                    hashes = *self.getMultimodalHashes().value();
                }
                return hashes;
            })
        .def_property_readonly("multimodal_positions",
            [](GenLlmReq& self)
            {
                std::optional<std::vector<GenLlmReq::SizeType32>> positions = std::nullopt;
                if (self.getMultimodalPositions())
                {
                    positions = *self.getMultimodalPositions().value();
                }
                return positions;
            })
        .def_property_readonly("multimodal_lengths",
            [](GenLlmReq& self)
            {
                std::optional<std::vector<GenLlmReq::SizeType32>> lengths = std::nullopt;
                if (self.getMultimodalLengths())
                {
                    lengths = *self.getMultimodalLengths().value();
                }
                return lengths;
            })
        .def_property_readonly("position_ids",
            [](GenLlmReq& self)
            {
                std::optional<std::vector<GenLlmReq::SizeType32>> positionIds = std::nullopt;
                if (self.getPositionIds())
                {
                    positionIds = *self.getPositionIds().value();
                }
                return positionIds;
            })
        .def_property(
            "draft_tokens",
            [](GenLlmReq& self)
            {
                std::optional<GenLlmReq::VecTokens> draftTokens = std::nullopt;
                if (self.hasDraftTokens())
                {
                    draftTokens = *self.getDraftTokens();
                }
                return draftTokens;
            },
            [](GenLlmReq& self, std::optional<GenLlmReq::VecTokens> const& draftTokens)
            {
                if (draftTokens)
                {
                    self.setDraftTokens(std::make_shared<GenLlmReq::VecTokens>(draftTokens.value()));
                }
            })
        .def_property("is_dummy_request", &GenLlmReq::isDummyRequest, &GenLlmReq::setIsDummyRequest)
        .def_property_readonly("return_perf_metrics", &GenLlmReq::getReturnPerfMetrics)
        .def_property("use_draft_model", &GenLlmReq::useDraftModel, &GenLlmReq::setUseDraftModel)
        .def("get_unique_tokens", py::overload_cast<GenLlmReq::SizeType32>(&GenLlmReq::getUniqueTokens, py::const_),
            py::arg("beam"))
        .def("get_unique_tokens", py::overload_cast<>(&GenLlmReq::getUniqueTokens, py::const_))
        .def("get_encoder_unique_tokens",
            [](GenLlmReq& self)
            {
                auto const& encoderUniqueTokens = self.getEncoderUniqueTokens();
                if (encoderUniqueTokens.has_value() && encoderUniqueTokens.value())
                {
                    return std::optional<GenLlmReq::VecUniqueTokens>(*encoderUniqueTokens.value());
                }
                return std::optional<GenLlmReq::VecUniqueTokens>(std::nullopt);
            });

    py::classh<tb::LlmRequest, GenLlmReq>(m, "LlmRequest", pybind11::dynamic_attr())
        .def(py::init<>(
                 [](tb::LlmRequest::RequestIdType request_id, tb::LlmRequest::SizeType32 max_new_tokens,
                     std::vector<tb::LlmRequest::TokenIdType> input_tokens, runtime::SamplingConfig sampling_config,
                     bool is_streaming, std::optional<tb::LlmRequest::SizeType32> end_id,
                     std::optional<tb::LlmRequest::SizeType32> pad_id, std::optional<at::Tensor> embedding_bias,
                     std::optional<at::Tensor> bad_words_list, std::optional<at::Tensor> stop_words_list,
                     std::optional<std::vector<tb::LlmRequest::SizeType32>> position_ids,
                     std::optional<at::Tensor> prompt_embedding_table,
                     std::optional<tb::LlmRequest::SizeType32> prompt_vocab_size,
                     std::optional<std::vector<std::vector<tb::LlmRequest::SizeType32>>> multimodal_hashes,
                     std::optional<std::vector<tb::LlmRequest::SizeType32>> multimodal_positions,
                     std::optional<std::vector<tb::LlmRequest::SizeType32>> multimodal_lengths,
                     std::optional<at::Tensor> multimodal_embedding, std::optional<at::Tensor> mrope_rotary_cos_sin,
                     std::optional<tb::LlmRequest::SizeType32> mrope_position_deltas,
                     std::optional<LoraTaskIdType> lora_task_id, std::optional<at::Tensor> lora_weights,
                     std::optional<at::Tensor> lora_config,
                     std::optional<executor::LookaheadDecodingConfig> lookahead_config,
                     std::optional<executor::KvCacheRetentionConfig> kv_cache_retention_config, bool return_log_probs,
                     bool return_context_logits, bool return_generation_logits,
                     std::optional<tb::LlmRequest::VecTokens> draft_tokens, std::optional<at::Tensor> draft_logits,
                     bool exclude_input_from_output,
                     std::optional<tb::LlmRequest::LogitsPostProcessor> logits_post_processor,
                     bool apply_logits_post_processor_batched,
                     std::optional<tb::LlmRequest::VecTokens> encoder_input_tokens, bool return_encoder_output,
                     std::optional<tb::LlmRequest::RequestIdType> client_id, executor::PriorityType priority,
                     std::optional<at::Tensor> encoder_input_features,
                     std::optional<tb::LlmRequest::SizeType32> encoder_output_length,
                     std::optional<at::Tensor> cross_attention_mask, tb::LlmRequestType llm_request_type,
                     std::optional<tb::LlmRequest::VecTokenExtraIds> input_token_extra_ids,
                     tb::LlmRequest::SizeType32 num_return_sequences, std::optional<executor::EagleConfig> eagle_config,
                     std::optional<at::Tensor> skip_cross_attn_blocks, bool return_perf_metrics,
                     std::optional<executor::GuidedDecodingParams> guided_decoding_params,
                     std::optional<tb::LlmRequest::SizeType32> language_adapter_uid,
                     std::optional<tb::LlmRequest::MillisecondsType> allotted_time_ms,
                     std::optional<executor::ContextPhaseParams> context_phase_params,
                     std::optional<tb::LlmRequest::CacheSaltIDType> cache_salt_id,
                     std::optional<tb::LlmRequest::TimePoint> arrival_time)
                 {
                     auto makeOptionalTensor = [](std::optional<at::Tensor> const& atTensor, bool unsqueeze = false)
                     {
                         std::optional<tb::LlmRequest::TensorPtr> tensorPtr = std::nullopt;
                         if (atTensor)
                         {
                             tensorPtr = tr::TorchView::of(atTensor.value());
                             if (unsqueeze)
                             {
                                 (*tensorPtr)->unsqueeze(0);
                             }
                         }
                         return tensorPtr;
                     };

                     auto embedding_bias_tensor_ptr = makeOptionalTensor(embedding_bias, true);
                     auto bad_words_list_tensor_ptr = makeOptionalTensor(bad_words_list, true);
                     auto stop_words_list_tensor_ptr = makeOptionalTensor(stop_words_list, true);
                     auto prompt_embedding_table_tensor_ptr = makeOptionalTensor(prompt_embedding_table);
                     auto multimodal_embedding_tensor_ptr = makeOptionalTensor(multimodal_embedding);
                     auto lora_weights_tensor_ptr = makeOptionalTensor(lora_weights);
                     auto mrope_rotary_cos_sin_tensor_ptr = makeOptionalTensor(mrope_rotary_cos_sin);
                     auto lora_config_tensor_ptr = makeOptionalTensor(lora_config);
                     auto draft_logits_tensor_ptr = makeOptionalTensor(draft_logits);
                     auto encoder_input_features_tensor_ptr = makeOptionalTensor(encoder_input_features);
                     auto cross_attention_mask_tensor_ptr = makeOptionalTensor(cross_attention_mask);
                     auto skip_cross_attn_blocks_tensor_ptr = makeOptionalTensor(skip_cross_attn_blocks);

                     return tb::LlmRequest{request_id, max_new_tokens, input_tokens, sampling_config, is_streaming,
                         end_id, pad_id, embedding_bias_tensor_ptr, bad_words_list_tensor_ptr,
                         stop_words_list_tensor_ptr, position_ids, prompt_embedding_table_tensor_ptr, prompt_vocab_size,
                         multimodal_hashes, multimodal_positions, multimodal_lengths, multimodal_embedding_tensor_ptr,
                         mrope_rotary_cos_sin_tensor_ptr, mrope_position_deltas, lora_task_id, lora_weights_tensor_ptr,
                         lora_config_tensor_ptr, lookahead_config, kv_cache_retention_config, return_log_probs,
                         return_context_logits, return_generation_logits, draft_tokens, draft_logits_tensor_ptr,
                         exclude_input_from_output, logits_post_processor, apply_logits_post_processor_batched,
                         encoder_input_tokens, return_encoder_output, client_id, priority,
                         encoder_input_features_tensor_ptr, encoder_output_length, cross_attention_mask_tensor_ptr,
                         llm_request_type, input_token_extra_ids, num_return_sequences, eagle_config,
                         skip_cross_attn_blocks_tensor_ptr, return_perf_metrics, guided_decoding_params,
                         language_adapter_uid, allotted_time_ms, context_phase_params, cache_salt_id, arrival_time};
                 }),
            py::arg("request_id"), py::arg("max_new_tokens"), py::arg("input_tokens"), py::arg("sampling_config"),
            py::arg("is_streaming"), py::arg("end_id") = std::nullopt, py::arg("pad_id") = std::nullopt,
            py::arg("embedding_bias") = std::nullopt, py::arg("bad_words_list") = std::nullopt,
            py::arg("stop_words_list") = std::nullopt, py::arg("position_ids") = std::nullopt,
            py::arg("prompt_embedding_table") = std::nullopt, py::arg("prompt_vocab_size") = std::nullopt,
            py::arg("multimodal_hashes") = std::nullopt, py::arg("multimodal_positions") = std::nullopt,
            py::arg("multimodal_lengths") = std::nullopt, py::arg("multimodal_embedding") = std::nullopt,
            py::arg("mrope_rotary_cos_sin") = std::nullopt, py::arg("mrope_position_deltas") = std::nullopt,
            py::arg("lora_task_id") = std::nullopt, py::arg("lora_weights") = std::nullopt,
            py::arg("lora_config") = std::nullopt, py::arg("lookahead_config") = std::nullopt,
            py::arg("kv_cache_retention_config") = std::nullopt, py::arg("return_log_probs") = false,
            py::arg("return_context_logits") = false, py::arg("return_generation_logits") = false,
            py::arg("draft_tokens") = std::nullopt, py::arg("draft_logits") = std::nullopt,
            py::arg("exclude_input_from_output") = false, py::arg("logits_post_processor") = std::nullopt,
            py::arg("apply_logits_post_processor_batched") = false, py::arg("encoder_input_tokens") = std::nullopt,
            py::arg("return_encoder_output") = false, py::arg("client_id") = std::nullopt,
            py::arg("priority") = executor::Request::kDefaultPriority, py::arg("encoder_input_features") = std::nullopt,
            py::arg("encoder_output_len") = std::nullopt, py::arg("cross_attention_mask") = std::nullopt,
            py::arg_v("llm_request_type", tb::LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION,
                "LlmRequestType.LLMREQUEST_TYPE_CONTEXT_AND_GENERATION"),
            py::arg("input_token_extra_ids") = std::nullopt, py::arg("num_return_sequences") = 1,
            py::arg("eagle_config") = std::nullopt, py::arg("skip_cross_attn_blocks") = std::nullopt,
            py::arg("return_perf_metrics") = false, py::arg("guided_decoding_params") = std::nullopt,
            py::arg("language_adapter_uid") = std::nullopt, py::arg("allotted_time_ms") = std::nullopt,
            py::arg("context_phase_params") = std::nullopt, py::arg("cache_salt_id") = std::nullopt,
            py::arg("arrival_time") = std::nullopt)
        .def("check_token_id_range", &tb::LlmRequest::checkTokenIdRange, py::arg("vocab_size"))
        .def(py::init<tb::LlmRequest const&>())
        .def("validate", &tb::LlmRequest::validate, py::arg("max_input_len"), py::arg("max_seq_len"),
            py::arg("max_draft_len"), py::arg("vocab_size_padded"), py::arg("max_endocer_input_len") = std::nullopt,
            py::arg("enable_kv_cache_reuse") = false)
        .def("create_response", &tb::LlmRequest::createResponse, py::arg("use_fast_logits") = false,
            py::arg("mpi_world_rank") = 0)
        .def("create_child_request", &tb::LlmRequest::createChildRequest, py::arg("child_id"))
        .def("create_result", &tb::LlmRequest::createResult, py::arg("use_fast_logits") = false,
            py::arg("mpi_world_rank") = 0)
        .def("create_serialized_result",
            [](tb::LlmRequest& self, bool use_fast_logits = false, int mpi_world_rank = 0)
            {
                std::vector<char> serialized_result;
                bool is_final = false;
                self.createSerializedResult(serialized_result, is_final, use_fast_logits, mpi_world_rank);
                return std::make_tuple(py::bytes(serialized_result.data(), serialized_result.size()), is_final);
            })
        .def("move_prompt_embedding_table_to_gpu", &tb::LlmRequest::movePromptEmbeddingTableToGpu, py::arg("manager"))
        .def("move_lora_weights_to_gpu", &tb::LlmRequest::moveLoraWeightsToGpu, py::arg("manager"))
        .def("finish_by_reason", &tb::LlmRequest::finishByReason, py::arg("finish_reason"))
        .def("set_first_scheduled_time", &tb::LlmRequest::setFirstScheduledTime)
        .def("update_perf_metrics", &tb::LlmRequest::updatePerfMetrics, py::arg("iter_counter"))
        .def("remove_lora_tensors", &tb::LlmRequest::removeLoraTensors)
        .def_readwrite_static("global_steady_clock_offset", &tb::LlmRequest::sGlobalSteadyClockOffset);

    py::classh<tb::SequenceSlotManager>(m, "SequenceSlotManager")
        .def(py::init<tb::SequenceSlotManager::SlotIdType, uint64_t>(), py::arg("max_num_slots"),
            py::arg("max_sequence_idle_microseconds"))
        .def("get_sequence_slot", &tb::SequenceSlotManager::getSequenceSlot, py::arg("start_flag"),
            py::arg("sequence_id"))
        .def("free_sequence_slot", &tb::SequenceSlotManager::freeSequenceSlot, py::arg("sequence_id"))
        .def("free_idle_sequence_slots", &tb::SequenceSlotManager::freeIdleSequenceSlots);

    py::classh<tb::rnn_state_manager::RnnStateManager>(m, "RnnStateManager")
        .def(py::init<tr::SizeType32, tr::ModelConfig, tr::WorldConfig, tr::BufferManager>(),
            py::arg("max_num_sequences"), py::arg("model_config"), py::arg("world_config"), py::arg("buffer_manager"));

    m.def(
        "add_new_tokens_to_requests",
        [](std::vector<std::shared_ptr<tb::LlmRequest>>& requests,
            std::vector<tb::LlmRequest::TokenIdType> const& tokens, int beam_idx)
        {
            TLLM_CHECK_WITH_INFO(requests.size() == tokens.size(), "Expected the same number of requests and tokens.");

            for (int i = 0; i < requests.size(); ++i)
            {
                requests[i]->addNewToken(tokens[i], beam_idx);
            }
        },
        py::arg("requests"), py::arg("tokens"), py::arg("beam_idx"),
        "Add new tokens to multiple LLM requests. The tokens vector should contain tokens for beam beam_idx of all "
        "requests in order.");

    m.def(
        "make_decoding_batch_input",
        [](tb::DecoderInputBuffers& decoderInputBuffers, runtime::decoder::DecoderState& decoderState,
            std::vector<std::shared_ptr<tb::LlmRequest>> const& contextRequests,
            std::vector<std::shared_ptr<tb::LlmRequest>> const& genRequests, tr::ITensor::SharedPtr const& logits,
            int beamWidth, std::vector<int> const& numContextLogitsPrefixSum, tr::BufferManager const& manager)
        {
            std::vector<int> activeSlots;
            std::vector<int> generationSteps;
            std::vector<std::vector<tr::ITensor::SharedConstPtr>> logitsVec = {{}};

            for (int i = 0; i < contextRequests.size(); ++i)
            {
                if (contextRequests[i]->isLastContextChunk())
                {
                    activeSlots.push_back(*contextRequests[i]->mSeqSlot);
                    generationSteps.push_back(contextRequests[i]->getDecodingIter());
                    auto contextLogitsOffset = numContextLogitsPrefixSum[i + 1] - 1;
                    tr::ITensor::SharedPtr logitsView = ITensor::slice(logits, contextLogitsOffset, 1);

                    if (beamWidth > 1)
                    {
                        // Tile logits of context requests
                        auto const logitsShape = logitsView->getShape();
                        auto const logitsType = logitsView->getDataType();
                        auto decoderLogits = manager.gpu(ITensor::makeShape({beamWidth, logitsShape.d[1]}), logitsType);
                        tensorrt_llm::runtime::kernels::tileTensor(
                            *decoderLogits, *logitsView, beamWidth, manager.getStream());
                        decoderLogits->unsqueeze(0);
                        logitsVec[0].push_back(std::move(decoderLogits));
                    }
                    else
                    {
                        logitsView->unsqueeze(1);
                        logitsVec[0].push_back(std::move(logitsView));
                    }
                }
            }

            auto genLogitsOffset = numContextLogitsPrefixSum.back();
            for (int i = 0; i < genRequests.size(); ++i)
            {
                if (genRequests[i]->isGenerationInProgressState())
                {
                    activeSlots.push_back(*genRequests[i]->mSeqSlot);
                    generationSteps.push_back(genRequests[i]->getDecodingIter());

                    auto logitsOffset = genLogitsOffset + i * beamWidth;
                    auto numberOfLogits = beamWidth;
                    tr::ITensor::SharedPtr logitsView = ITensor::slice(logits, logitsOffset, numberOfLogits);
                    logitsView->unsqueeze(0);
                    logitsVec[0].push_back(std::move(logitsView));
                }
            }

            auto& batchSlots = decoderInputBuffers.forwardBatchSlots;
            batchSlots[0]->resize(activeSlots.size());
            auto batchSlotsRange = tr::BufferRange<SizeType32>(*batchSlots[0]);
            for (int i = 0; i < activeSlots.size(); ++i)
            {
                batchSlotsRange[i] = activeSlots[i];
            }

            decoderInputBuffers.batchLogits = logitsVec;

            auto const maxBeamWidth = decoderState.getMaxBeamWidth();
            if (maxBeamWidth > 1)
            {
                // For Variable-Beam-Width-Search
                decoderState.getJointDecodingInput().generationSteps = generationSteps;
            }
        },
        py::arg("decoder_input_buffers"), py::arg("decoder_state"), py::arg("context_requests"),
        py::arg("generation_requests"), py::arg("logits"), py::arg("beam_width"),
        py::arg("num_context_logits_prefix_sum"), py::arg("buffer_manager"), "Make decoding batch input.");
}

} // namespace tensorrt_llm::pybind::batch_manager
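Editor's note: the top of the file above shows the standard recipe for surfacing C++ exceptions as catchable Python exception classes: create the Python exception with `PyErr_NewException`, attach it to the module with `add_object`, and install a capture-free translator. A self-contained sketch of the same pattern; the `CacheFullError` type, `errors_demo` module, and `fill` function are hypothetical:

    #include <pybind11/pybind11.h>

    #include <stdexcept>

    namespace py = pybind11;

    struct CacheFullError : std::runtime_error
    {
        using std::runtime_error::runtime_error;
    };

    PYBIND11_MODULE(errors_demo, m)  // hypothetical demo module
    {
        // static so the capture-free translator lambda below can refer to it.
        static PyObject* cacheFullExc = PyErr_NewException("errors_demo.CacheFullError", nullptr, nullptr);
        m.add_object("CacheFullError", py::handle(cacheFullExc));

        py::register_exception_translator(
            [](std::exception_ptr p)
            {
                try
                {
                    if (p)
                        std::rethrow_exception(p);
                }
                catch (CacheFullError const& e)
                {
                    PyErr_SetString(cacheFullExc, e.what());
                }
            });

        // Raises errors_demo.CacheFullError on the Python side.
        m.def("fill", [] { throw CacheFullError("cache is full"); });
    }

Python callers can then write `except errors_demo.CacheFullError`, which is exactly how `PeftTaskNotCachedException` and `LoraCacheFullException` were exposed above.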
@@ -1,28 +0,0 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "tensorrt_llm/pybind/common/customCasters.h"
#include <pybind11/pybind11.h>

namespace tensorrt_llm::pybind::batch_manager
{

void initBindings(pybind11::module_& m);

}
@@ -1,75 +0,0 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "buffers.h"

#include "tensorrt_llm/batch_manager/decoderBuffers.h"
#include "tensorrt_llm/runtime/torch.h"

#include <ATen/ATen.h>
#include <pybind11/functional.h>
#include <pybind11/operators.h>
#include <pybind11/stl.h>
#include <pybind11/stl_bind.h>
#include <torch/extension.h>

namespace py = pybind11;
namespace tb = tensorrt_llm::batch_manager;
namespace tr = tensorrt_llm::runtime;

using tr::SizeType32;

namespace tensorrt_llm::pybind::batch_manager
{

void Buffers::initBindings(pybind11::module_& m)
{
    py::class_<tb::DecoderInputBuffers>(m, "DecoderInputBuffers")
        .def(py::init<tr::SizeType32, tr::SizeType32, tr::BufferManager>(), py::arg("max_batch_size"),
            py::arg("max_tokens_per_engine_step"), py::arg("manager"))
        .def_readwrite("setup_batch_slots", &tb::DecoderInputBuffers::setupBatchSlots)
        .def_readwrite("setup_batch_slots_device", &tb::DecoderInputBuffers::setupBatchSlotsDevice)
        .def_readwrite("fill_values", &tb::DecoderInputBuffers::fillValues)
        .def_readwrite("fill_values_device", &tb::DecoderInputBuffers::fillValuesDevice)
        .def_readwrite("inputs_ids", &tb::DecoderInputBuffers::inputsIds)
        .def_readwrite("batch_logits", &tb::DecoderInputBuffers::batchLogits)
        .def_readwrite("forward_batch_slots", &tb::DecoderInputBuffers::forwardBatchSlots)
        .def_readwrite("decoder_logits", &tb::DecoderInputBuffers::decoderLogits)
        .def_readwrite("max_decoder_steps", &tb::DecoderInputBuffers::maxDecoderSteps);

    py::class_<tb::DecoderOutputBuffers>(m, "DecoderOutputBuffers")
        .def_readwrite("sequence_lengths_host", &tb::DecoderOutputBuffers::sequenceLengthsHost)
        .def_readwrite("finished_sum_host", &tb::DecoderOutputBuffers::finishedSumHost)
        .def_property_readonly("new_output_tokens_host",
            [](tb::DecoderOutputBuffers& self) { return tr::Torch::tensor(self.newOutputTokensHost); })
        .def_readwrite("cum_log_probs_host", &tb::DecoderOutputBuffers::cumLogProbsHost)
        .def_readwrite("log_probs_host", &tb::DecoderOutputBuffers::logProbsHost)
        .def_readwrite("finish_reasons_host", &tb::DecoderOutputBuffers::finishReasonsHost);

    py::class_<tb::SlotDecoderBuffers>(m, "SlotDecoderBuffers")
        .def(py::init<runtime::SizeType32, runtime::SizeType32, runtime::BufferManager const&>(),
            py::arg("max_beam_width"), py::arg("max_seq_len"), py::arg("buffer_manager"))
        .def_readwrite("output_ids", &tb::SlotDecoderBuffers::outputIds)
        .def_readwrite("output_ids_host", &tb::SlotDecoderBuffers::outputIdsHost)
        .def_readwrite("sequence_lengths_host", &tb::SlotDecoderBuffers::sequenceLengthsHost)
        .def_readwrite("cum_log_probs", &tb::SlotDecoderBuffers::cumLogProbs)
        .def_readwrite("cum_log_probs_host", &tb::SlotDecoderBuffers::cumLogProbsHost)
        .def_readwrite("log_probs", &tb::SlotDecoderBuffers::logProbs)
        .def_readwrite("log_probs_host", &tb::SlotDecoderBuffers::logProbsHost)
        .def_readwrite("finish_reasons_host", &tb::SlotDecoderBuffers::finishReasonsHost);
}
} // namespace tensorrt_llm::pybind::batch_manager
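Editor's note: the buffer bindings above mix two access styles. `def_readwrite` exposes a member directly, while `def_property_readonly` with a lambda converts on access (here via `tr::Torch::tensor`, which wraps a runtime buffer as a torch tensor). A compact sketch of the distinction, with hypothetical `Buffers`/`buffers_demo` names and `py::cast` standing in for the Torch conversion:

    #include <pybind11/pybind11.h>
    #include <pybind11/stl.h>  // converts std::vector to/from Python lists

    #include <vector>

    namespace py = pybind11;

    struct Buffers
    {
        std::vector<float> fillValues;   // plain member, read/write from Python
        std::vector<int> newTokensHost;  // exposed only through a converting getter
    };

    PYBIND11_MODULE(buffers_demo, m)  // hypothetical demo module
    {
        py::class_<Buffers>(m, "Buffers")
            .def(py::init())
            .def_readwrite("fill_values", &Buffers::fillValues)
            // Read-only computed view; py::cast stands in for tr::Torch::tensor above.
            .def_property_readonly("new_output_tokens_host",
                [](Buffers const& self) { return py::cast(self.newTokensHost); });
    }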
@@ -1,30 +0,0 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "tensorrt_llm/pybind/common/customCasters.h"
#include <pybind11/pybind11.h>

namespace tensorrt_llm::pybind::batch_manager
{
class Buffers
{
public:
    static void initBindings(pybind11::module_& m);
};
} // namespace tensorrt_llm::pybind::batch_manager
@ -1,204 +0,0 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "cacheTransceiver.h"
|
||||
#include "tensorrt_llm/batch_manager/cacheTransceiver.h"
|
||||
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
|
||||
#include "tensorrt_llm/common/bindingUtils.h"
|
||||
#include "tensorrt_llm/executor/executor.h"
|
||||
#include <ATen/ATen.h>
|
||||
#include <pybind11/functional.h>
|
||||
#include <pybind11/operators.h>
|
||||
#include <pybind11/pytypes.h>
|
||||
#include <pybind11/stl.h>
|
||||
#include <pybind11/stl_bind.h>
|
||||
#include <torch/extension.h>
|
||||
#include <typeinfo>
|
||||
|
||||
using SizeType32 = tensorrt_llm::runtime::SizeType32;
|
||||
|
||||
namespace tb = tensorrt_llm::batch_manager;
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
class PyCacheTransceiver : public tb::BaseCacheTransceiver
|
||||
{
|
||||
public:
|
||||
// using BaseCacheTransceiver::BaseCacheTransceiver; // Inherit constructors
|
||||
|
||||
void respondAndSendAsync(tb::LlmRequest* llmRequest) override
|
||||
{
|
||||
PYBIND11_OVERLOAD_PURE(void, tb::BaseCacheTransceiver, respondAndSendAsync, llmRequest);
|
||||
}
|
||||
|
||||
void requestAndReceiveSync(tb::LlmRequest* llmRequest) override
|
||||
{
|
||||
PYBIND11_OVERLOAD_PURE(void, tb::BaseCacheTransceiver, requestAndReceiveSync, llmRequest);
|
||||
}
|
||||
|
||||
void requestAndReceiveAsync(tb::LlmRequest* llmRequest) override
|
||||
{
|
||||
PYBIND11_OVERLOAD_PURE(void, tb::BaseCacheTransceiver, requestAndReceiveAsync, llmRequest);
|
||||
}
|
||||
|
||||
using RequestStatuses = tb::RequestStatuses;
|
||||
|
||||
RequestStatuses checkContextTransferStatus(
|
||||
std::optional<int> const& atLeastRequestNum = std::nullopt, bool markComplete = false) override
|
||||
{
|
||||
PYBIND11_OVERLOAD_PURE(
|
||||
RequestStatuses, tb::BaseCacheTransceiver, checkContextTransferStatus, atLeastRequestNum, markComplete);
|
||||
}
|
||||
|
||||
void checkGenTransferStatus(std::optional<int> const& atLeastRequestNum = std::nullopt) override
|
||||
{
|
||||
PYBIND11_OVERLOAD_PURE(void, tb::BaseCacheTransceiver, checkGenTransferStatus, atLeastRequestNum);
|
||||
}
|
||||
|
||||
bool checkGenTransferComplete() const override
|
||||
{
|
||||
PYBIND11_OVERLOAD_PURE(bool, tb::BaseCacheTransceiver, checkGenTransferComplete);
|
||||
}
|
||||
|
||||
bool cancelRequest(tb::LlmRequest* llmRequest) override
|
||||
{
|
||||
PYBIND11_OVERLOAD_PURE(bool, tb::BaseCacheTransceiver, cancelRequest, llmRequest);
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void tb::CacheTransceiverBindings::initBindings(py::module_& m)
|
||||
{
|
||||
py::classh<tb::BaseCacheTransceiver, PyCacheTransceiver>(m, "BaseCacheTransceiver")
|
||||
.def("respond_and_send_async", &BaseCacheTransceiver::respondAndSendAsync)
|
||||
.def("request_and_receive_sync", &BaseCacheTransceiver::requestAndReceiveSync)
|
||||
.def("request_and_receive_async", &BaseCacheTransceiver::requestAndReceiveAsync)
|
||||
.def(
|
||||
"check_context_transfer_status",
|
||||
[](tb::BaseCacheTransceiver& self, std::optional<int> const& atLeastRequestNum, bool markComplete = false)
|
||||
{
|
||||
RequestStatuses result;
|
||||
{
|
||||
py::gil_scoped_release release;
|
||||
result = self.checkContextTransferStatus(atLeastRequestNum, markComplete);
|
||||
}
|
||||
|
||||
auto completedRequestIds
|
||||
= std::vector<int64_t>(result.completedRequestIds.begin(), result.completedRequestIds.end());
|
||||
auto errorRequestIds
|
||||
= std::vector<int64_t>(result.errorRequestIds.begin(), result.errorRequestIds.end());
|
||||
return py::make_tuple(completedRequestIds, errorRequestIds);
|
||||
},
|
            py::arg("at_least_request_num") = std::nullopt, py::arg("mark_complete") = false)
        .def("check_gen_transfer_status", &BaseCacheTransceiver::checkGenTransferStatus,
            py::call_guard<py::gil_scoped_release>())
        .def("check_gen_transfer_complete", &BaseCacheTransceiver::checkGenTransferComplete)
        .def("cancel_request", &BaseCacheTransceiver::cancelRequest);

    py::enum_<executor::kv_cache::CacheState::AttentionType>(m, "AttentionType")
        .value("DEFAULT", executor::kv_cache::CacheState::AttentionType::kDEFAULT)
        .value("MLA", executor::kv_cache::CacheState::AttentionType::kMLA);

    py::classh<tb::CacheTransceiver, tb::BaseCacheTransceiver>(m, "CacheTransceiver")
        .def(py::init<tb::kv_cache_manager::BaseKVCacheManager*, std::vector<SizeType32>, SizeType32, SizeType32,
                 runtime::WorldConfig, std::vector<SizeType32>, nvinfer1::DataType,
                 executor::kv_cache::CacheState::AttentionType, std::optional<executor::CacheTransceiverConfig>>(),
            py::arg("cache_manager"), py::arg("num_kv_heads_per_layer"), py::arg("size_per_head"),
            py::arg("tokens_per_block"), py::arg("world_config"), py::arg("attention_layer_num_per_pp"),
            py::arg("dtype"), py::arg("attention_type"), py::arg("cache_transceiver_config") = std::nullopt);

    py::classh<tb::CacheTransceiverComm>(m, "CacheTransceiverComm")
        .def(py::init(
                 [](py::object pg_obj, std::string pybind11_abi)
                 {
                     return new CacheTransceiverComm(
                         common::get_intrusive_ptr<c10d::ProcessGroup, py::error_already_set>(
                             pg_obj.ptr(), pybind11_abi));
                 }),
            py::arg("process_group"), py::arg("pybind11_abi"))
        .def("get_rank", &tb::CacheTransceiverComm::getRank)
        .def("get_size", &tb::CacheTransceiverComm::getSize)
        .def("split", &tb::CacheTransceiverComm::split, py::arg("color"), py::arg("key"))
        .def(
            "allgather",
            [](tb::CacheTransceiverComm const& self, int64_t input)
            {
                std::vector<int64_t> out(static_cast<size_t>(self.getSize()));
                c10d::AllgatherOptions options;
                bool ok = self.allgather(input, std::ref(out), options);
                return py::make_tuple(ok, out);
            },
            py::arg("input"))
        .def(
            "allgather",
            [](tb::CacheTransceiverComm const& self, double input)
            {
                std::vector<double> out(static_cast<size_t>(self.getSize()));
                c10d::AllgatherOptions options;
                bool ok = self.allgather(input, std::ref(out), options);
                return py::make_tuple(ok, out);
            },
            py::arg("input"))
        .def(
            "allgather",
            [](tb::CacheTransceiverComm const& self, char input)
            {
                std::vector<char> out(static_cast<size_t>(self.getSize()));
                c10d::AllgatherOptions options;
                bool ok = self.allgather(input, std::ref(out), options);
                return py::make_tuple(ok, out);
            },
            py::arg("input"))
        .def(
            "allgatherv",
            [](tb::CacheTransceiverComm const& self, std::vector<int64_t> input, std::vector<int> const& sizes)
            {
                int total_size = std::accumulate(sizes.begin(), sizes.end(), 0);
                std::vector<int64_t> output(total_size);
                bool ok = self.allgatherv(std::ref(input), std::ref(output), std::cref(sizes));
                return py::make_tuple(ok, output);
            },
            py::arg("input"), py::arg("sizes"))
        .def(
            "allgatherv",
            [](tb::CacheTransceiverComm const& self, std::vector<double> input, std::vector<int> const& sizes)
            {
                int total_size = std::accumulate(sizes.begin(), sizes.end(), 0);
                std::vector<double> output(total_size);
                bool ok = self.allgatherv(std::ref(input), std::ref(output), std::cref(sizes));
                return py::make_tuple(ok, output);
            },
            py::arg("input"), py::arg("sizes"))
        .def(
            "allgatherv",
            [](tb::CacheTransceiverComm const& self, std::vector<char> input, std::vector<int> const& sizes)
            {
                int total_size = std::accumulate(sizes.begin(), sizes.end(), 0);
                std::vector<char> output(total_size);
                bool ok = self.allgatherv(std::ref(input), std::ref(output), std::cref(sizes));
                return py::make_tuple(ok, output);
            },
            py::arg("input"), py::arg("sizes"));

    py::class_<tb::kv_cache_manager::CacheTransBufferManager>(m, "CacheTransBufferManager")
        .def(py::init<tb::kv_cache_manager::BaseKVCacheManager*, std::optional<size_t>>(), py::arg("cache_manager"),
            py::arg("max_num_tokens") = std::nullopt)
        .def_static("pre_alloc_buffer_size", &tb::kv_cache_manager::CacheTransBufferManager::preAllocBufferSize,
            py::arg("cache_size_bytes_per_token_per_window"), py::arg("tokens_per_block"),
            py::arg("cache_transceiver_config") = py::none());
}
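The `allgather`/`allgatherv` overloads removed above follow a standard pybind11 idiom: one typed C++ lambda per element type, each returning a `(success, values)` tuple, with pybind11 dispatching on the Python argument type. A minimal, self-contained sketch of that idiom, assuming nothing from TensorRT-LLM (the `Comm` class and its fake collective are hypothetical stand-ins):

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <vector>

namespace py = pybind11;

struct Comm
{
    int size = 4;

    // Hypothetical collective: replicate the input once per "rank".
    template <typename T>
    bool allgather(T input, std::vector<T>& out) const
    {
        out.assign(static_cast<size_t>(size), input);
        return true;
    }
};

PYBIND11_MODULE(comm_sketch, m)
{
    py::class_<Comm>(m, "Comm")
        .def(py::init<>())
        // One lambda per element type; pybind11 tries each overload in order.
        .def("allgather",
            [](Comm const& self, int64_t input)
            {
                std::vector<int64_t> out;
                bool ok = self.allgather(input, out);
                return py::make_tuple(ok, out); // (success, gathered values)
            })
        .def("allgather",
            [](Comm const& self, double input)
            {
                std::vector<double> out;
                bool ok = self.allgather(input, out);
                return py::make_tuple(ok, out);
            });
}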
@ -1,30 +0,0 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "tensorrt_llm/pybind/common/customCasters.h"
#include <pybind11/pybind11.h>

namespace tensorrt_llm::batch_manager
{
class CacheTransceiverBindings
{
public:
    static void initBindings(pybind11::module_& m);
};
} // namespace tensorrt_llm::batch_manager
@ -1,47 +0,0 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/pybind/batch_manager/kvCacheConnector.h"

namespace
{
using KvCacheConnectorManager = tensorrt_llm::batch_manager::kv_connector::KvCacheConnectorManager;

namespace tb = tensorrt_llm::batch_manager;

class PyKvCacheConnectorManager : public KvCacheConnectorManager, py::trampoline_self_life_support
{
public:
    using KvCacheConnectorManager::KvCacheConnectorManager;

    SizeType32 getNumNewMatchedTokens(tb::LlmRequest const& request, SizeType32 numComputedTokens) override
    {
        PYBIND11_OVERRIDE_PURE_NAME(SizeType32, KvCacheConnectorManager, "get_num_new_matched_tokens",
            getNumNewMatchedTokens, request, numComputedTokens);
    }
};

} // namespace

void tensorrt_llm::batch_manager::kv_cache_manager::KVCacheManagerConnectorBindings::initBindings(py::module_& m)
{
    py::class_<tb::kv_connector::KvCacheConnectorManager, PyKvCacheConnectorManager, py::smart_holder>(
        m, "KvCacheConnectorManager")
        .def(py::init<>())
        .def("get_num_new_matched_tokens", &tb::kv_connector::KvCacheConnectorManager::getNumNewMatchedTokens,
            py::arg("request"), py::arg("num_computed_tokens"));
}
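`PyKvCacheConnectorManager` above is a pybind11 "trampoline": it forwards a C++ pure-virtual call to a Python override registered under a snake_case name via `PYBIND11_OVERRIDE_PURE_NAME`. A minimal sketch of the same pattern, with a hypothetical `Connector` base in place of the TensorRT-LLM types:

#include <pybind11/pybind11.h>

namespace py = pybind11;

class Connector
{
public:
    virtual ~Connector() = default;
    virtual int getNumNewMatchedTokens(int numComputedTokens) = 0;
};

class PyConnector : public Connector
{
public:
    using Connector::Connector;

    int getNumNewMatchedTokens(int numComputedTokens) override
    {
        // Looks up "get_num_new_matched_tokens" on the Python subclass and calls it.
        PYBIND11_OVERRIDE_PURE_NAME(
            int, Connector, "get_num_new_matched_tokens", getNumNewMatchedTokens, numComputedTokens);
    }
};

PYBIND11_MODULE(connector_sketch, m)
{
    py::class_<Connector, PyConnector>(m, "Connector")
        .def(py::init<>())
        .def("get_num_new_matched_tokens", &Connector::getNumNewMatchedTokens, py::arg("num_computed_tokens"));
}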
@ -1,39 +0,0 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "tensorrt_llm/batch_manager/kvCacheConnector.h"
#include <pybind11/pybind11.h>

namespace py = pybind11;

namespace tensorrt_llm::batch_manager::kv_cache_manager
{
class KVCacheManagerConnectorBindings
{
public:
    static void initBindings(pybind11::module_& m);
};
} // namespace tensorrt_llm::batch_manager::kv_cache_manager

namespace tensorrt_llm::pybind::batch_manager::kv_connector
{

using namespace tensorrt_llm::batch_manager::kv_connector;

} // namespace tensorrt_llm::pybind::batch_manager::kv_connector
@ -1,570 +0,0 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "kvCacheManager.h"
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
#include "tensorrt_llm/batch_manager/peftCacheManager.h"
#include "tensorrt_llm/pybind/common/bindTypes.h"
#include "tensorrt_llm/runtime/torch.h"
#include "tensorrt_llm/runtime/torchView.h"

#include <ATen/ATen.h>
#include <pybind11/functional.h>
#include <pybind11/operators.h>
#include <pybind11/stl.h>
#include <pybind11/stl_bind.h>
#include <torch/extension.h>

namespace tb = tensorrt_llm::batch_manager;
namespace tbc = tensorrt_llm::batch_manager::kv_connector;
namespace tbk = tensorrt_llm::batch_manager::kv_cache_manager;
namespace tr = tensorrt_llm::runtime;
namespace py = pybind11;
using BlockKey = tbk::BlockKey;
using VecUniqueTokens = tensorrt_llm::runtime::VecUniqueTokens;
using SizeType32 = tensorrt_llm::runtime::SizeType32;
using TokenIdType = tensorrt_llm::runtime::TokenIdType;
using VecTokens = std::vector<TokenIdType>;
using CudaStreamPtr = std::shared_ptr<tensorrt_llm::runtime::CudaStream>;

namespace
{
std::optional<tensorrt_llm::runtime::ITensor::UniquePtr> from_torch(std::optional<at::Tensor> torchPtr)
{
    if (torchPtr)
    {
        return tr::TorchView::of(torchPtr.value());
    }
    return std::nullopt;
}

class PyKvCacheManager : public tbk::BaseKVCacheManager
{
public:
    // using BaseKVCacheManager::BaseKVCacheManager; // Inherit constructors
    void allocatePools(bool useUvm = false) override
    {
        PYBIND11_OVERLOAD_PURE(void, tbk::BaseKVCacheManager, allocatePools, useUvm);
    }

    void releasePools() override
    {
        PYBIND11_OVERLOAD_PURE(void, tbk::BaseKVCacheManager, releasePools);
    }

    void startScheduling() override
    {
        PYBIND11_OVERLOAD_PURE(void, tbk::BaseKVCacheManager, startScheduling);
    }

    SizeType32 getTokensPerBlock() const override
    {
        PYBIND11_OVERLOAD_PURE(SizeType32, tbk::BaseKVCacheManager, getTokensPerBlock);
    }

    SizeType32 getMaxNumBlocks() const override
    {
        PYBIND11_OVERLOAD_PURE(SizeType32, tbk::BaseKVCacheManager, getMaxNumBlocks);
    }

    SizeType32 getNumPools() const override
    {
        PYBIND11_OVERLOAD_PURE(SizeType32, tbk::BaseKVCacheManager, getNumPools);
    }

    tbk::KvCacheStats getKvCacheStats() const override
    {
        PYBIND11_OVERLOAD_PURE(tbk::KvCacheStats, tbk::BaseKVCacheManager, getKvCacheStats);
    }

    void addToken(tb::LlmRequest::RequestIdType requestId) override
    {
        PYBIND11_OVERLOAD_PURE(void, tbk::BaseKVCacheManager, addToken, requestId);
    }

    void addSequence(tb::LlmRequest::RequestIdType requestId, SizeType32 inputLength, SizeType32 beamWidth,
        tensorrt_llm::common::OptionalRef<tb::LlmRequest> llmRequest = std::nullopt) override
    {
        PYBIND11_OVERLOAD_PURE(
            void, tbk::BaseKVCacheManager, addSequence, requestId, inputLength, beamWidth, llmRequest);
    }

    std::optional<tbk::KVCacheBlock::IdType> removeSequence(tb::LlmRequest::RequestIdType requestId,
        tensorrt_llm::common::OptionalRef<tb::LlmRequest const> llmRequest = std::nullopt,
        bool pinOnRelease = false) override
    {
        PYBIND11_OVERLOAD_PURE(std::optional<tbk::KVCacheBlock::IdType>, tbk::BaseKVCacheManager, removeSequence,
            requestId, llmRequest, pinOnRelease);
    }

    std::vector<tbk::KVCacheBlock::IdType> storeBlocksForReuse(tb::LlmRequest::RequestIdType requestId,
        tensorrt_llm::common::OptionalRef<tb::LlmRequest const> llmRequest, bool pinBlocks) override
    {
        PYBIND11_OVERLOAD_PURE(std::vector<tbk::KVCacheBlock::IdType>, tbk::BaseKVCacheManager, storeBlocksForReuse,
            requestId, llmRequest, pinBlocks);
    }

    tbk::GenerationRequest const& getSequence(tb::LlmRequest::RequestIdType requestId) const override
    {
        PYBIND11_OVERLOAD_PURE(tbk::GenerationRequest const&, tbk::BaseKVCacheManager, getSequence, requestId);
    }

    void schedulingRemoveSequence(tb::LlmRequest::RequestIdType requestId) override
    {
        PYBIND11_OVERLOAD_PURE(void, tbk::BaseKVCacheManager, schedulingRemoveSequence, requestId);
    }

    tensorrt_llm::runtime::ITensor::SharedPtr getBlockPoolPointers() const override
    {
        PYBIND11_OVERLOAD_PURE(
            tensorrt_llm::runtime::ITensor::UniquePtr, tbk::BaseKVCacheManager, getBlockPoolPointers);
    }

    tensorrt_llm::runtime::ITensor::SharedPtr getLayerToPoolMapping() const override
    {
        PYBIND11_OVERLOAD_PURE(
            tensorrt_llm::runtime::ITensor::UniquePtr, tbk::BaseKVCacheManager, getLayerToPoolMapping);
    }

    void getBlockOffsetsOfBatch(tensorrt_llm::runtime::ITensor& output, SizeType32 firstBatchSlotIdx,
        SizeType32 batchSize, SizeType32 beamWidth) const override
    {
        PYBIND11_OVERLOAD_PURE(
            void, tbk::BaseKVCacheManager, getBlockOffsetsOfBatch, output, firstBatchSlotIdx, batchSize, beamWidth);
    }

    SizeType32 copyBlockOffsets(tensorrt_llm::runtime::ITensor& output, SizeType32 outputSlotOffset,
        tb::LlmRequest::RequestIdType requestId) const override
    {
        PYBIND11_OVERLOAD_PURE(
            SizeType32, tbk::BaseKVCacheManager, copyBlockOffsets, output, outputSlotOffset, requestId);
    }

    bool isEnableBlockReuse() const override
    {
        PYBIND11_OVERLOAD_PURE(bool, tbk::BaseKVCacheManager, isEnableBlockReuse);
    }

    void rewindKVCache(tb::LlmRequest::RequestIdType requestId, SizeType32 rewindLengths) override
    {
        PYBIND11_OVERLOAD_PURE(void, tbk::BaseKVCacheManager, rewindKVCache, requestId, rewindLengths);
    }

    bool isCrossKv() const override
    {
        PYBIND11_OVERLOAD_PURE(bool, tbk::BaseKVCacheManager, isCrossKv);
    }

    std::optional<BlockKey> findNewContextBlock(
        VecUniqueTokens const& uniqueTokens, tb::LlmRequest const& llmRequest) const override
    {
        PYBIND11_OVERLOAD_PURE(
            std::optional<BlockKey>, tbk::BaseKVCacheManager, findNewContextBlock, uniqueTokens, llmRequest);
    }

    void storeContextBlocks(tb::LlmRequest const& llmRequest) override
    {
        PYBIND11_OVERLOAD_PURE(void, tbk::BaseKVCacheManager, storeContextBlocks, llmRequest);
    }

    std::vector<std::vector<SizeType32>> const& getCacheBlockIds(
        tb::LlmRequest::RequestIdType requestId, SizeType32 windowSize) const override
    {
        PYBIND11_OVERLOAD_PURE(std::vector<std::vector<SizeType32>> const&, tbk::BaseKVCacheManager, getCacheBlockIds,
            requestId, windowSize);
    }

    std::vector<std::vector<std::vector<SizeType32>>> getBatchCacheBlockIds(
        std::vector<tb::LlmRequest::RequestIdType> const& requestIds, SizeType32 windowSize) const override
    {
        PYBIND11_OVERLOAD_PURE(std::vector<std::vector<std::vector<SizeType32>>>, tbk::BaseKVCacheManager,
            getBatchCacheBlockIds, requestIds, windowSize);
    }

    SizeType32 getUsedNumBlocks() const override
    {
        PYBIND11_OVERLOAD_PURE(SizeType32, tbk::BaseKVCacheManager, getUsedNumBlocks);
    }

    SizeType32 getNumFreeBlocks() const override
    {
        PYBIND11_OVERLOAD_PURE(SizeType32, tbk::BaseKVCacheManager, getNumFreeBlocks);
    }

    tbk::BlockManager const& getBlockManager() const override
    {
        PYBIND11_OVERLOAD_PURE(tbk::BlockManager const&, tbk::BaseKVCacheManager, getBlockManager);
    }

    std::deque<tensorrt_llm::executor::KVCacheEvent> getLatestEvents(
        std::optional<std::chrono::milliseconds> timeout = std::nullopt) const override
    {
        PYBIND11_OVERLOAD_PURE(
            std::deque<tensorrt_llm::executor::KVCacheEvent>, tbk::BaseKVCacheManager, getLatestEvents, timeout);
    }

    tensorrt_llm::runtime::ITensor::SharedPtr getPrimaryPool(SizeType32 poolIdx) const override
    {
        PYBIND11_OVERLOAD_PURE(
            tensorrt_llm::runtime::ITensor::SharedPtr, tbk::BaseKVCacheManager, getPrimaryPool, poolIdx);
    }

    tensorrt_llm::runtime::ITensor::SharedPtr getIndexerKCachePool() const override
    {
        PYBIND11_OVERLOAD_PURE(
            tensorrt_llm::runtime::ITensor::SharedPtr, tbk::BaseKVCacheManager, getIndexerKCachePool);
    }

    tensorrt_llm::runtime::ITensor::SharedPtr getUniquePrimaryPool() const override
    {
        PYBIND11_OVERLOAD_PURE(
            tensorrt_llm::runtime::ITensor::SharedPtr, tbk::BaseKVCacheManager, getUniquePrimaryPool);
    }

    SizeType32 getPoolLayerIdx(SizeType32 layer_idx) const override
    {
        PYBIND11_OVERLOAD_PURE(SizeType32, tbk::BaseKVCacheManager, getPoolLayerIdx, layer_idx);
    }

    void syncTransferManagerWithBufferManager() override
    {
        PYBIND11_OVERLOAD_PURE(void, tbk::BaseKVCacheManager, syncTransferManagerWithBufferManager);
    }

    void refreshBlocks() override
    {
        PYBIND11_OVERLOAD_PURE(void, tbk::BaseKVCacheManager, refreshBlocks);
    }

    void flushIterationEvents() override
    {
        PYBIND11_OVERLOAD_PURE(void, tbk::BaseKVCacheManager, flushIterationEvents);
    }
};

// TODO: Deduplicate executor bindings KvCacheStats
class PyBasePeftCacheManager : public tb::BasePeftCacheManager
{
public:
    ~PyBasePeftCacheManager() override = default;

    void addRequestPeft(tb::BasePeftCacheManager::LlmRequestPtr llmRequest, bool tryGpuCache = true) override
    {
        PYBIND11_OVERLOAD_PURE(void, tb::BasePeftCacheManager, addRequestPeft, llmRequest, tryGpuCache);
    }

    tb::BasePeftCacheManager::PeftTable ensureBatch(tb::RequestVector const& contextRequests,
        tb::RequestVector const& generationRequests, bool resetGpuCache = false) override
    {
        PYBIND11_OVERLOAD_PURE(tb::BasePeftCacheManager::PeftTable, tb::BasePeftCacheManager, ensureBatch,
            contextRequests, generationRequests, resetGpuCache);
    }

    void resetDeviceCache() override
    {
        PYBIND11_OVERLOAD_PURE(void, tb::BasePeftCacheManager, resetDeviceCache);
    }

    void markRequestDone(tb::LlmRequest const& llmReq, bool pause = false) override
    {
        PYBIND11_OVERLOAD_PURE(void, tb::BasePeftCacheManager, markRequestDone, llmReq, pause);
    }

    tr::SizeType32 getMaxDevicePages() const override
    {
        PYBIND11_OVERLOAD_PURE(tr::SizeType32, tb::BasePeftCacheManager, getMaxDevicePages);
    }

    tr::SizeType32 getMaxHostPages() const override
    {
        PYBIND11_OVERLOAD_PURE(tr::SizeType32, tb::BasePeftCacheManager, getMaxHostPages);
    }

    tr::SizeType32 determineNumPages(std::shared_ptr<tb::LlmRequest> llmRequest) const override
    {
        PYBIND11_OVERLOAD_PURE(tr::SizeType32, tb::BasePeftCacheManager, determineNumPages, llmRequest);
    }

    bool enabled() const override
    {
        PYBIND11_OVERLOAD_PURE(bool, tb::BasePeftCacheManager, enabled);
    }
};
} // namespace

void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(py::module_& m)
{
    py::class_<tbk::KvCacheStats>(m, "KvCacheStats")
        .def(py::init<>())
        .def_readwrite("max_num_blocks", &tbk::KvCacheStats::maxNumBlocks)
        .def_readwrite("free_num_blocks", &tbk::KvCacheStats::freeNumBlocks)
        .def_readwrite("used_num_blocks", &tbk::KvCacheStats::usedNumBlocks)
        .def_readwrite("tokens_per_block", &tbk::KvCacheStats::toksPerBlock)
        .def_readwrite("alloc_total_blocks", &tbk::KvCacheStats::allocTotalBlocks)
        .def_readwrite("alloc_new_blocks", &tbk::KvCacheStats::allocNewBlocks)
        .def_readwrite("reused_blocks", &tbk::KvCacheStats::reusedBlocks)
        .def_readwrite("missed_blocks", &tbk::KvCacheStats::missedBlocks)
        .def_readwrite("cache_hit_rate", &tbk::KvCacheStats::cacheHitRate)
        .def_readwrite("num_free_blocks_per_window_size", &tbk::KvCacheStats::numFreeBlocksPerWindowSize)
        .def_readonly("allocated_bytes", &tbk::KvCacheStats::allocatedBytes);

    py::class_<tbk::TempAttentionWindowInputs>(m, "TempAttentionWindowInputs")
        .def(py::init<>())
        .def_readwrite("paged_context_fmha", &tbk::TempAttentionWindowInputs::pagedContextFMHA)
        .def_readwrite("max_input_len", &tbk::TempAttentionWindowInputs::maxInputLen)
        .def_readwrite("max_num_tokens", &tbk::TempAttentionWindowInputs::maxNumTokens);

    py::class_<tbk::BlockKey>(m, "BlockKey")
        .def(py::init<>())
        .def(py::init<VecTokens const&, std::optional<tr::LoraTaskIdType>>(), py::arg("tokens"),
            py::arg("lora_task_id") = std::nullopt)
        .def(py::init<bool, std::optional<tr::LoraTaskIdType>, VecUniqueTokens const&>(), py::arg("uses_extra_ids"),
            py::arg("lora_task_id"), py::arg("unique_tokens"))
        .def_readonly("uses_extra_ids", &tbk::BlockKey::usesExtraIds)
        .def_readonly("lora_task_id", &tbk::BlockKey::loraTaskId)
        .def_readonly("unique_tokens", &tbk::BlockKey::uniqueTokens);

    py::class_<tbk::BlockKeyHasher>(m, "BlockKeyHasher")
        .def_static("hash", &tbk::BlockKeyHasher::hash, py::arg("block_key"), py::arg("parent_hash") = 0);

    py::class_<tbk::KVCacheEventManager, std::shared_ptr<tbk::KVCacheEventManager>>(m, "KVCacheEventManager")
        .def(py::init<size_t, std::optional<SizeType32>, std::optional<SizeType32>, SizeType32>(),
            py::arg("max_kv_event_entries"), py::arg("attention_dp_rank") = std::nullopt,
            py::arg("attention_dp_size") = std::nullopt, py::arg("attention_dp_events_gather_period_ms") = 5);

    py::classh<tbk::BaseKVCacheManager, PyKvCacheManager>(m, "BaseKVCacheManager")
        .def_static("calculate_max_num_blocks", &tbk::BaseKVCacheManager::calculateMaxNumBlocks, py::arg("config"),
            py::arg("is_cross_attention"), py::arg("dtype"), py::arg("model_config"), py::arg("world_config"),
            py::arg("window_size_to_layers"), py::arg("allotted_primary_mem_bytes"),
            py::arg("allotted_secondary_mem_bytes"), py::arg("extra_cost_memory"), py::arg("kv_factor"),
            py::call_guard<py::gil_scoped_release>())
        .def("allocate_pools", &BaseKVCacheManager::allocatePools, py::call_guard<py::gil_scoped_release>())
        .def("release_pools", &BaseKVCacheManager::releasePools, py::call_guard<py::gil_scoped_release>())
        .def("start_scheduling", &BaseKVCacheManager::startScheduling, py::call_guard<py::gil_scoped_release>())
        .def_property_readonly("tokens_per_block", &BaseKVCacheManager::getTokensPerBlock)
        .def_property_readonly("max_num_blocks", &BaseKVCacheManager::getMaxNumBlocks)
        .def_property_readonly("num_pools", &BaseKVCacheManager::getNumPools)
        .def("get_kv_cache_stats", &BaseKVCacheManager::getKvCacheStats, py::call_guard<py::gil_scoped_release>())
        .def_property_readonly("max_blocks_per_seq",
            [](tbk::BaseKVCacheManager& self) { return self.getOffsetTableDimensions().maxBlocksPerSeq; })
        .def("get_needed_blocks_one_step", &BaseKVCacheManager::getNeededBlocksOneStep,
            py::call_guard<py::gil_scoped_release>())
        .def("get_remaining_blocks_to_completion", &BaseKVCacheManager::getRemainingBlocksToCompletion,
            py::call_guard<py::gil_scoped_release>())
        .def("add_token", &BaseKVCacheManager::addToken, py::call_guard<py::gil_scoped_release>())
        .def("add_sequence", &BaseKVCacheManager::addSequence, py::call_guard<py::gil_scoped_release>())
        .def("remove_sequence", &BaseKVCacheManager::removeSequence, py::call_guard<py::gil_scoped_release>())
        .def("pin_blocks", &BaseKVCacheManager::pinBlocks, py::call_guard<py::gil_scoped_release>())
        .def("scheduling_remove_sequence", &BaseKVCacheManager::schedulingRemoveSequence,
            py::call_guard<py::gil_scoped_release>())
        .def(
            "get_block_pool_pointers",
            [](tbk::BaseKVCacheManager& self)
            {
                std::optional<at::Tensor> block_pool_pointers{std::nullopt};
                auto tensor = self.getBlockPoolPointers();
                if (tensor)
                {
                    std::shared_ptr<tensorrt_llm::runtime::ITensor> _tensor = std::move(tensor);
                    block_pool_pointers = tr::Torch::tensor(_tensor);
                }
                return block_pool_pointers;
            },
            py::call_guard<py::gil_scoped_release>())
        .def(
            "get_block_scale_pool_pointers",
            [](tbk::BaseKVCacheManager& self)
            {
                std::optional<at::Tensor> block_scale_pool_pointers{std::nullopt};
                auto tensor = self.getBlockScalePoolPointers();
                if (tensor)
                {
                    std::shared_ptr<tensorrt_llm::runtime::ITensor> _tensor = std::move(tensor);
                    block_scale_pool_pointers = tr::Torch::tensor(_tensor);
                }
                return block_scale_pool_pointers;
            },
            py::call_guard<py::gil_scoped_release>())
        .def(
            "get_layer_to_pool_mapping",
            [](tbk::BaseKVCacheManager& self)
            {
                std::optional<at::Tensor> layer_to_pool_mapping{std::nullopt};
                auto tensor = self.getLayerToPoolMapping();
                if (tensor)
                {
                    std::shared_ptr<tensorrt_llm::runtime::ITensor> _tensor = std::move(tensor);
                    layer_to_pool_mapping = tr::Torch::tensor(_tensor);
                }
                return layer_to_pool_mapping;
            },
            py::call_guard<py::gil_scoped_release>())
        .def(
            "get_primary_pool_data",
            [](tbk::BaseKVCacheManager& self, SizeType32 layer_idx) -> at::Tensor
            {
                auto pool = tr::Torch::tensor(self.getPrimaryPool(layer_idx));
                auto pool_layer_idx = self.getPoolLayerIdx(layer_idx);
                return pool.index({torch::indexing::Slice(), pool_layer_idx});
            },
            py::call_guard<py::gil_scoped_release>())
        .def(
            "get_indexer_k_cache_pool_data",
            [](tbk::BaseKVCacheManager& self, SizeType32 layer_idx) -> at::Tensor
            {
                auto pool = tr::Torch::tensor(self.getIndexerKCachePool());
                return pool.index({torch::indexing::Slice(), layer_idx});
            },
            py::call_guard<py::gil_scoped_release>())
        .def(
            "get_unique_primary_pool", [](tbk::BaseKVCacheManager& self) { return self.getUniquePrimaryPool(); },
            py::call_guard<py::gil_scoped_release>())
        .def(
            "get_block_offsets_of_batch",
            [](tbk::BaseKVCacheManager& self, at::Tensor output, SizeType32 firstBatchSlotIdx, SizeType32 batchSize,
                SizeType32 beamWidth)
            {
                auto _output = from_torch(output);
                TLLM_CHECK_WITH_INFO(_output.has_value(), "Invalid output tensor.");
                self.getBlockOffsetsOfBatch(*(_output.value()), firstBatchSlotIdx, batchSize, beamWidth);
            },
            py::call_guard<py::gil_scoped_release>())
        .def(
            "copy_block_offsets",
            [](tbk::BaseKVCacheManager& self, at::Tensor output, SizeType32 outputSlotOffset,
                tb::LlmRequest::RequestIdType requestId)
            {
                auto _output = from_torch(output);
                TLLM_CHECK_WITH_INFO(_output.has_value(), "Invalid output tensor.");
                auto maxBlockCount = self.copyBlockOffsets(*(_output.value()), outputSlotOffset, requestId);
                return maxBlockCount;
            },
            py::call_guard<py::gil_scoped_release>())
        .def(
            "copy_batch_block_offsets",
            [](tbk::BaseKVCacheManager& self, at::Tensor output,
                std::vector<tb::LlmRequest::RequestIdType> const& requestIds, SizeType32 const beamWidth,
                SizeType32 const offset)
            {
                auto _output = from_torch(output);
                TLLM_CHECK_WITH_INFO(_output.has_value(), "Invalid output tensor.");
                for (size_t i = 0; i < requestIds.size(); ++i)
                {
                    self.copyBlockOffsets(*(_output.value()), i * beamWidth + offset, requestIds[i]);
                }
            },
            py::call_guard<py::gil_scoped_release>())
        .def(
            "get_latest_events",
            [](tbk::BaseKVCacheManager& self, std::optional<double> timeout_ms = std::nullopt)
            {
                if (timeout_ms)
                {
                    return self.getLatestEvents(std::chrono::milliseconds(static_cast<int64_t>(*timeout_ms)));
                }
                return self.getLatestEvents(std::nullopt);
            },
            py::arg("timeout_ms") = std::nullopt, py::call_guard<py::gil_scoped_release>())
        .def_property_readonly("enable_block_reuse", &BaseKVCacheManager::isEnableBlockReuse)
        .def("rewind_kv_cache", &BaseKVCacheManager::rewindKVCache, py::call_guard<py::gil_scoped_release>())
        .def_property_readonly("cross_kv", &BaseKVCacheManager::isCrossKv)
        .def("store_context_blocks", &BaseKVCacheManager::storeContextBlocks, py::call_guard<py::gil_scoped_release>())
        .def("store_blocks_for_reuse", &BaseKVCacheManager::storeBlocksForReuse,
            py::call_guard<py::gil_scoped_release>())
        .def("find_new_context_block", &BaseKVCacheManager::findNewContextBlock, py::arg("unique_tokens"),
            py::arg("llm_request"), py::call_guard<py::gil_scoped_release>())
        .def("get_cache_block_ids", &BaseKVCacheManager::getCacheBlockIds, py::call_guard<py::gil_scoped_release>())
        .def("get_batch_cache_block_ids", &BaseKVCacheManager::getBatchCacheBlockIds,
            py::call_guard<py::gil_scoped_release>())
        .def("flush_iteration_events", &BaseKVCacheManager::flushIterationEvents,
            py::call_guard<py::gil_scoped_release>())
        .def("sync_transfer_manager_with_buffer_manager", &BaseKVCacheManager::syncTransferManagerWithBufferManager,
            py::call_guard<py::gil_scoped_release>())
        .def("refresh_blocks", &BaseKVCacheManager::refreshBlocks, py::call_guard<py::gil_scoped_release>())
        .def("get_last_block_id", &BaseKVCacheManager::getLastBlockId, py::call_guard<py::gil_scoped_release>())
        .def("unpin_blocks_by_id", &BaseKVCacheManager::unpinBlocksById, py::call_guard<py::gil_scoped_release>())
        .def("reset_reuse_state", &BaseKVCacheManager::resetReuseState, py::call_guard<py::gil_scoped_release>());

    py::enum_<tbk::CacheType>(m, "CacheType")
        .value("SELF", tbk::CacheType::kSELF)
        .value("CROSS", tbk::CacheType::kCROSS)
        .value("SELFKONLY", tbk::CacheType::kSELFKONLY);

    py::classh<tbk::KVCacheManager, tbk::BaseKVCacheManager>(m, "KVCacheManager")
        .def(py::init<std::vector<SizeType32> const&, SizeType32, SizeType32,
                 std::map<SizeType32, std::tuple<SizeType32, SizeType32>> const&, SizeType32, SizeType32,
                 std::vector<SizeType32> const&, std::optional<tbk::TempAttentionWindowInputs> const&,
                 nvinfer1::DataType, SizeType32, bool, int64_t, bool, bool, tbk::CacheType,
                 std::optional<tensorrt_llm::executor::RetentionPriority>, std::shared_ptr<tbk::KVCacheEventManager>,
                 bool, bool, std::shared_ptr<tbc::KvCacheConnectorManager>, bool, SizeType32, SizeType32>(),
            py::arg("num_kv_heads_per_layer"), py::arg("size_per_head"), py::arg("tokens_per_block"),
            py::arg("blocks_per_window"), py::arg("max_num_sequences"), py::arg("max_beam_width"),
            py::arg("max_attention_window_vec"), py::arg("temp_attention_window_inputs"), py::arg("dtype"),
            py::arg("sink_token_length"), py::arg("stream"), py::arg("max_sequence_length"),
            py::arg("enable_block_reuse") = false, py::arg("onboard_blocks") = true,
            py::arg_v("cache_type", tbk::CacheType::kSELF, "bindings.internal.batch_manager.CacheType.SELF"),
            py::arg("secondary_offload_min_priority") = std::nullopt, py::arg("event_manager") = nullptr,
            py::arg("enable_partial_reuse") = true, py::arg("copy_on_partial_reuse") = true,
            py::arg("kv_connector_manager") = nullptr, py::arg("enable_indexer_k_cache") = false,
            py::arg("indexer_k_cache_quant_block_size") = 128, py::arg("indexer_k_cache_index_head_dim") = 0,
            py::call_guard<py::gil_scoped_release>())
        .def(
            "scheduling_has_free_blocks",
            [](tbk::KVCacheManager& self, SizeType32 numRequired, SizeType32 windowSize)
            { return self.getBlockManager().schedulingHasFreeBlocks(numRequired, windowSize); },
            py::arg("num_required"), py::arg("window_size"), py::call_guard<py::gil_scoped_release>())
        .def_property_readonly(
            "is_variable_window", [](tbk::KVCacheManager& self) { return self.getBlockManager().isVariableWindow(); });
}

void tb::BasePeftCacheManagerBindings::initBindings(py::module_& m)
{
    py::classh<tb::BasePeftCacheManager, PyBasePeftCacheManager>(m, "BasePeftCacheManager")
        .def("add_request_peft", &tb::BasePeftCacheManager::addRequestPeft, py::arg("request"),
            py::arg("try_gpu_cache") = true, py::call_guard<py::gil_scoped_release>())
        .def(
            "ensure_batch",
            [](tb::BasePeftCacheManager& self, tb::RequestVector const& contextRequests,
                tb::RequestVector const& generationRequests, bool resetGpuCache)
            { return self.ensureBatch(contextRequests, generationRequests, resetGpuCache); },
            py::arg("context_requests"), py::arg("generation_requests"), py::arg("reset_gpu_cache") = false,
            py::call_guard<py::gil_scoped_release>())
        .def(
            "reset_device_cache", &tb::BasePeftCacheManager::resetDeviceCache, py::call_guard<py::gil_scoped_release>())
        .def("mark_request_done", &tb::BasePeftCacheManager::markRequestDone, py::arg("request"),
            py::arg("pause") = false, py::call_guard<py::gil_scoped_release>())
        .def_property_readonly("max_device_pages", &tb::BasePeftCacheManager::getMaxDevicePages)
        .def_property_readonly("max_host_pages", &tb::BasePeftCacheManager::getMaxHostPages)
        .def("determine_num_pages", &tb::BasePeftCacheManager::determineNumPages, py::arg("request"),
            py::call_guard<py::gil_scoped_release>())
        .def_property_readonly("enabled", &tb::BasePeftCacheManager::enabled);

    py::classh<tb::PeftCacheManager, tb::BasePeftCacheManager>(m, "PeftCacheManager")
        .def(py::init<tb::PeftCacheManagerConfig, tr::ModelConfig, tr::WorldConfig, tr::BufferManager>(),
            py::arg("config"), py::arg("model_config"), py::arg("world_config"), py::arg("buffer_manager"),
            py::call_guard<py::gil_scoped_release>())
        .def("is_task_cached", &tb::PeftCacheManager::isTaskCached, py::arg("taskId"),
            py::call_guard<py::gil_scoped_release>())
        .def("is_task_cached_device", &tb::PeftCacheManager::isTaskCachedDevice, py::arg("taskId"),
            py::call_guard<py::gil_scoped_release>())
        .def("ensure_batch_map_task_id", &tb::PeftCacheManager::ensureBatchMapTaskId, py::arg("context_requests"),
            py::arg("generation_requests"), py::arg("reset_gpu_cache") = false,
            py::call_guard<py::gil_scoped_release>());

    py::classh<tb::NoOpPeftCacheManager, tb::BasePeftCacheManager>(m, "NoOpPeftCacheManager")
        .def(py::init<>(), py::call_guard<py::gil_scoped_release>());
}
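Nearly every binding in the removed file above is registered with `py::call_guard<py::gil_scoped_release>()`, which drops the GIL for the duration of the C++ call so other Python threads can make progress while the cache manager works. A minimal sketch of the guard in isolation (`slow_sum` is a hypothetical function, not TensorRT-LLM API):

#include <pybind11/pybind11.h>
#include <chrono>
#include <thread>

namespace py = pybind11;

long slow_sum(long n)
{
    std::this_thread::sleep_for(std::chrono::milliseconds(50)); // stand-in for heavy C++ work
    long s = 0;
    for (long i = 0; i < n; ++i)
        s += i;
    return s;
}

PYBIND11_MODULE(guard_sketch, m)
{
    // The guard releases the GIL before entering slow_sum and re-acquires it on
    // return; the guarded function must not touch Python objects in between.
    m.def("slow_sum", &slow_sum, py::arg("n"), py::call_guard<py::gil_scoped_release>());
}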
@ -1,39 +0,0 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "tensorrt_llm/pybind/common/customCasters.h"
#include <pybind11/pybind11.h>

namespace tensorrt_llm::batch_manager::kv_cache_manager
{
class KVCacheManagerBindings
{
public:
    static void initBindings(pybind11::module_& m);
};
} // namespace tensorrt_llm::batch_manager::kv_cache_manager

namespace tensorrt_llm::batch_manager
{
class BasePeftCacheManagerBindings
{
public:
    static void initBindings(pybind11::module_& m);
};
} // namespace tensorrt_llm::batch_manager
@ -1,111 +0,0 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "kvCacheManagerV2Utils.h"

#include "tensorrt_llm/batch_manager/kvCacheManagerV2Utils.h"

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/stl_bind.h>

namespace py = pybind11;

namespace tensorrt_llm::batch_manager::kv_cache_manager_v2
{

void KVCacheManagerV2UtilsBindings::initBindings(py::module_& module)
{
    // Bind DiskAddress struct
    py::class_<DiskAddress>(module, "DiskAddress")
        .def(py::init<int, ssize_t>(), py::arg("fd"), py::arg("pos"))
        .def_readwrite("fd", &DiskAddress::fd)
        .def_readwrite("pos", &DiskAddress::pos);

    // Bind Task template instantiations
    py::class_<Task<DiskAddress, DiskAddress>>(module, "DiskToDiskTask")
        .def(py::init<DiskAddress, DiskAddress>(), py::arg("dst"), py::arg("src"))
        .def_readwrite("dst", &Task<DiskAddress, DiskAddress>::dst)
        .def_readwrite("src", &Task<DiskAddress, DiskAddress>::src);

    py::class_<Task<MemAddress, DiskAddress>>(module, "DiskToHostTask")
        .def(py::init<MemAddress, DiskAddress>(), py::arg("dst"), py::arg("src"))
        .def_readwrite("dst", &Task<MemAddress, DiskAddress>::dst)
        .def_readwrite("src", &Task<MemAddress, DiskAddress>::src);

    py::class_<Task<DiskAddress, MemAddress>>(module, "HostToDiskTask")
        .def(py::init<DiskAddress, MemAddress>(), py::arg("dst"), py::arg("src"))
        .def_readwrite("dst", &Task<DiskAddress, MemAddress>::dst)
        .def_readwrite("src", &Task<DiskAddress, MemAddress>::src);

    py::class_<Task<MemAddress, MemAddress>>(module, "MemToMemTask")
        .def(py::init<MemAddress, MemAddress>(), py::arg("dst"), py::arg("src"))
        .def_readwrite("dst", &Task<MemAddress, MemAddress>::dst)
        .def_readwrite("src", &Task<MemAddress, MemAddress>::src);

    // Bind copy functions
    module.def(
        "copy_disk_to_disk",
        [](std::vector<Task<DiskAddress, DiskAddress>> tasks, ssize_t numBytes, uintptr_t stream) -> int
        { return copyDiskToDisk(std::move(tasks), numBytes, reinterpret_cast<CUstream>(stream)); },
        py::arg("tasks"), py::arg("num_bytes"), py::arg("stream"), py::call_guard<py::gil_scoped_release>(),
        "Copy data from disk to disk using CUDA host function");

    module.def(
        "copy_disk_to_host",
        [](std::vector<Task<MemAddress, DiskAddress>> tasks, ssize_t numBytes, uintptr_t stream) -> int
        { return copyDiskToHost(std::move(tasks), numBytes, reinterpret_cast<CUstream>(stream)); },
        py::arg("tasks"), py::arg("num_bytes"), py::arg("stream"), py::call_guard<py::gil_scoped_release>(),
        "Copy data from disk to host using CUDA host function");

    module.def(
        "copy_host_to_disk",
        [](std::vector<Task<DiskAddress, MemAddress>> tasks, ssize_t numBytes, uintptr_t stream) -> int
        { return copyHostToDisk(std::move(tasks), numBytes, reinterpret_cast<CUstream>(stream)); },
        py::arg("tasks"), py::arg("num_bytes"), py::arg("stream"), py::call_guard<py::gil_scoped_release>(),
        "Copy data from host to disk using CUDA host function");

    module.def(
        "copy_host_to_host",
        [](std::vector<Task<MemAddress, MemAddress>> tasks, ssize_t numBytes, uintptr_t stream) -> int
        { return copyHostToHost(std::move(tasks), numBytes, reinterpret_cast<CUstream>(stream)); },
        py::arg("tasks"), py::arg("num_bytes"), py::arg("stream"), py::call_guard<py::gil_scoped_release>(),
        "Copy data from host to host using CUDA host function");

    module.def(
        "copy_host_to_device",
        [](std::vector<Task<MemAddress, MemAddress>> const& tasks, ssize_t numBytes, uintptr_t stream) -> int
        { return copyHostToDevice(tasks, numBytes, reinterpret_cast<CUstream>(stream)); },
        py::arg("tasks"), py::arg("num_bytes"), py::arg("stream"), py::call_guard<py::gil_scoped_release>(),
        "Copy data from host to device using CUDA kernels");

    module.def(
        "copy_device_to_host",
        [](std::vector<Task<MemAddress, MemAddress>> const& tasks, ssize_t numBytes, uintptr_t stream) -> int
        { return copyDeviceToHost(tasks, numBytes, reinterpret_cast<CUstream>(stream)); },
        py::arg("tasks"), py::arg("num_bytes"), py::arg("stream"), py::call_guard<py::gil_scoped_release>(),
        "Copy data from device to host using CUDA kernels");

    module.def(
        "copy_device_to_device",
        [](std::vector<Task<MemAddress, MemAddress>> const& tasks, ssize_t numBytes, uintptr_t stream) -> int
        { return copyDeviceToDevice(tasks, numBytes, reinterpret_cast<CUstream>(stream)); },
        py::arg("tasks"), py::arg("num_bytes"), py::arg("stream"), py::call_guard<py::gil_scoped_release>(),
        "Copy data from device to device using CUDA kernels");
}

} // namespace tensorrt_llm::batch_manager::kv_cache_manager_v2
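The copy bindings above accept the CUDA stream as a plain integer (`uintptr_t`) and `reinterpret_cast` it to a driver `CUstream`, which is how a raw stream handle (for example the integer exposed by a Python CUDA stream object) crosses the binding boundary without a dedicated type caster. A minimal sketch of that cast path, assuming only pybind11 and the CUDA driver header (`launch_copy` is a hypothetical function):

#include <pybind11/pybind11.h>
#include <cuda.h>
#include <cstdint>

namespace py = pybind11;

int launch_copy(CUstream stream)
{
    // Real code would enqueue work on `stream`; this sketch only shows the cast path.
    (void) stream;
    return 0;
}

PYBIND11_MODULE(stream_sketch, m)
{
    m.def(
        "launch_copy",
        // Python passes the stream address as an int; reinterpret it on the C++ side.
        [](std::uintptr_t stream) { return launch_copy(reinterpret_cast<CUstream>(stream)); },
        py::arg("stream"), py::call_guard<py::gil_scoped_release>());
}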
@ -1,29 +0,0 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <pybind11/pybind11.h>

namespace tensorrt_llm::batch_manager::kv_cache_manager_v2
{
class KVCacheManagerV2UtilsBindings
{
public:
    static void initBindings(pybind11::module_& module);
};
} // namespace tensorrt_llm::batch_manager::kv_cache_manager_v2
@ -1,131 +0,0 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "llmRequest.h"

#include "tensorrt_llm/batch_manager/llmRequest.h"
#include "tensorrt_llm/pybind/common/bindTypes.h"
#include "tensorrt_llm/runtime/torch.h"
#include "tensorrt_llm/runtime/torchUtils.h"
#include "tensorrt_llm/runtime/torchView.h"

#include <ATen/ATen.h>
#include <torch/extension.h>

#include <memory>

namespace tb = tensorrt_llm::batch_manager;
namespace tr = tensorrt_llm::runtime;
namespace tle = tensorrt_llm::executor;

using namespace tensorrt_llm::pybind::batch_manager;

using LlmRequestPtr = std::shared_ptr<tb::LlmRequest>;
using RequestList = std::list<LlmRequestPtr>;

namespace
{

std::optional<tb::LlmRequest::TensorPtr> from_torch(std::optional<LlmRequest::TensorPtr> torchPtr)
{
    if (torchPtr)
    {
        return tr::TorchView::of(torchPtr.value());
    }
    return std::nullopt;
}

} // namespace

std::optional<tb::LlmRequest::LogitsPostProcessor> LlmRequest::callbackAdapter(
    std::optional<LlmRequest::LogitsPostProcessor> callback)
{
    if (!callback)
    {
        return std::nullopt;
    }

    return [callback](RequestIdType reqId, tr::ITensor::SharedPtr& tensor, tb::LlmRequest::BeamTokens const& tokens,
               tr::BufferManager::CudaStreamPtr stream, std::optional<RequestIdType> clientId)
    {
        at::Tensor atTensor = tr::Torch::tensor(tensor);
        callback.value()(reqId, atTensor, tokens, runtime::TorchUtils::stream(*stream).unwrap(), clientId);
    };
}

std::shared_ptr<tb::LlmRequest> LlmRequest::toTrtLlm() const
{

    auto const draftTokens = std::make_shared<std::vector<TokenIdType>>(*mDraftTokens.get());
    auto const optDraftTokens = std::optional<std::shared_ptr<std::vector<TokenIdType>>>(draftTokens);
    auto const encoderInputTokens = mEncoderTokens.has_value()
        ? std::make_shared<std::vector<TokenIdType>>(*mEncoderTokens.value().get())
        : nullptr;
    auto const optEncoderInputTokens = std::optional<std::shared_ptr<std::vector<TokenIdType>>>(encoderInputTokens);
    return std::make_shared<tb::LlmRequest>( //
        mRequestId, //
        mMaxNewTokens, //
        std::make_shared<std::vector<TokenIdType>>(mTokens.at(0)), //
        mSamplingConfig, //
        mIsStreaming, //
        mEndId, //
        mPadId, //
        from_torch(mEmbeddingBias), //
        from_torch(mBadWordsList), //
        from_torch(mStopWordsList), //
        mPositionIds, //
        from_torch(mPromptEmbeddingTable), //
        mPromptVocabSize, //
        mMultimodalHashes, //
        mMultimodalPositions, //
        mMultimodalLengths, //
        from_torch(mMultimodalEmbedding), //
        from_torch(mMropeRotaryCosSin), //
        mMropePositionDeltas, //
        mLoraTaskId, //
        from_torch(mLoraWeights), //
        from_torch(mLoraConfig), //
        mLookaheadConfig, //
        mKvCacheRetentionConfig, //
        mReturnLogProbs, //
        mReturnContextLogits, //
        mReturnGenerationLogits, //
        optDraftTokens, //
        from_torch(mDraftLogits), //
        mExcludeInputFromOutput, //
        callbackAdapter(mLogitsPostProcessor), //
        mApplyLogitsPostProcessorBatched, //
        optEncoderInputTokens, //
        mReturnEncoderOutput, //
        mClientId, //
        mPriority, //
        from_torch(mEncoderInputFeatures), //
        mEncoderOutputLength, //
        from_torch(mCrossAttentionMask), //
        getLlmRequestType(), //
        std::nullopt, // inputTokenExtraIds
        mNumReturnSequences, //
        mEagleConfig, //
        from_torch(mSkipCrossAttnBlocks), //
        false, // returnPerfMetrics
        mGuidedDecodingParams, //
        mLanguageAdapterUid, //
        mAllottedTimeMs, //
        mContextPhaseParams, //
        mCacheSaltID, //
        mPerfMetrics.timingMetrics.arrivalTime //
    );
}
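`LlmRequest::callbackAdapter` above wraps an optional Python-facing callback into the `std::function` signature the C++ engine expects, converting arguments at the boundary. A simplified, self-contained sketch of the same adapter shape, using `std::vector<float>` and `py::object` as stand-ins for the real tensor types (all names here are hypothetical):

#include <pybind11/pybind11.h>
#include <pybind11/functional.h>
#include <pybind11/stl.h>
#include <functional>
#include <optional>
#include <vector>

namespace py = pybind11;

// Engine-side callback: operates on a raw buffer.
using EngineCallback = std::function<void(std::vector<float>&)>;
// Python-facing callback: operates on a converted Python object.
using PyCallback = std::function<void(py::object)>;

std::optional<EngineCallback> callback_adapter(std::optional<PyCallback> callback)
{
    if (!callback)
    {
        return std::nullopt;
    }
    // Capture the Python callable by value; convert the payload at the boundary.
    return [callback](std::vector<float>& data)
    {
        py::gil_scoped_acquire gil; // the engine may invoke this off the Python thread
        py::object view = py::cast(data);
        callback.value()(view);
    };
}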
@ -1,162 +0,0 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "tensorrt_llm/batch_manager/llmRequest.h"
#include "tensorrt_llm/pybind/common/customCasters.h"

#include <ATen/ATen.h>
#include <ATen/ops/tensor.h>
#include <memory>
#include <optional>
#include <pybind11/pybind11.h>

namespace tensorrt_llm::pybind::batch_manager
{

namespace tb = tensorrt_llm::batch_manager;

/* Unfortunately, torch's default pybind bindings don't know about c10::cuda::CUDAStream,
 * so we have to pass the more generic c10::Stream, and convert it back to a full-fledged
 * torch.cuda.Stream in python. See example in test/bindings/test_gpt_manager.py
 */
class LlmRequest : public tb::GenericLlmRequest<at::Tensor, c10::Stream>
{
public:
    using Base = GenericLlmRequest<at::Tensor, c10::Stream>;
    using TensorPtr = Base::TensorPtr;
    using SizeType32 = Base::SizeType32;
    using TokenIdType = Base::TokenIdType;
    using RequestIdType = Base::RequestIdType;
    using LoraTaskIdType = Base::LoraTaskIdType;
    using VecLogProbs = Base::VecLogProbs;
    using BeamTokens = Base::BeamTokens;
    using VecTokens = Base::VecTokens;
    using VecTokenExtraIds = Base::VecTokenExtraIds;
    using LogitsPostProcessor = Base::LogitsPostProcessor;
    using CacheSaltIDType = Base::CacheSaltIDType;

    LlmRequest(RequestIdType requestId, SizeType32 maxNewTokens, std::vector<TokenIdType> inputTokens,
        runtime::SamplingConfig samplingConfig, bool isStreaming, std::optional<SizeType32> endId = std::nullopt,
        std::optional<SizeType32> padId = std::nullopt, std::optional<TensorPtr> embeddingBias = std::nullopt,
        std::optional<TensorPtr> badWordsList = std::nullopt, std::optional<TensorPtr> stopWordsList = std::nullopt,
        std::optional<std::vector<SizeType32>> positionIds = std::nullopt,
        std::optional<TensorPtr> promptEmbeddingTable = std::nullopt,
        std::optional<SizeType32> promptVocabSize = std::nullopt,
        std::optional<std::vector<std::vector<SizeType32>>> multimodalHashes = std::nullopt,
        std::optional<std::vector<SizeType32>> multimodalPositions = std::nullopt,
        std::optional<std::vector<SizeType32>> multimodalLengths = std::nullopt,
        std::optional<TensorPtr> multimodalEmbedding = std::nullopt,
        std::optional<TensorPtr> mropeRotaryCosSin = std::nullopt,
        std::optional<SizeType32> mropePositionDeltas = std::nullopt,
        std::optional<LoraTaskIdType> loraTaskId = std::nullopt, std::optional<TensorPtr> loraWeights = std::nullopt,
        std::optional<TensorPtr> loraConfig = std::nullopt,
        std::optional<executor::LookaheadDecodingConfig> lookaheadConfig = std::nullopt,
        std::optional<executor::KvCacheRetentionConfig> kvCacheRetentionConfig = std::nullopt,
        bool returnLogProbs = false, bool returnContextLogits = false, bool returnGenerationLogits = false,
        std::optional<VecTokens> draftTokens = std::nullopt, std::optional<TensorPtr> draftLogits = std::nullopt,
        bool excludeInputFromOutput = false, std::optional<LogitsPostProcessor> logitsPostProcessor = std::nullopt,
        bool applyLogitsPostProcessorBatched = false, std::optional<VecTokens> encoderInputTokens = std::nullopt,
        bool returnEncoderOutput = false, std::optional<RequestIdType> clientId = std::nullopt,
        executor::PriorityType priority = executor::Request::kDefaultPriority,
        std::optional<TensorPtr> encoderInputFeatures = std::nullopt,
        std::optional<SizeType32> encoderOutputLength = std::nullopt,
        std::optional<TensorPtr> crossAttentionMask = std::nullopt,
        tb::LlmRequestType llmRequestType = tb::LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION,
        std::optional<VecTokenExtraIds> inputTokenExtraIds = std::nullopt, SizeType32 numReturnSequences = 1,
        std::optional<executor::EagleConfig> eagleConfig = std::nullopt,
        std::optional<TensorPtr> skipCrossAttnBlocks = std::nullopt, bool returnPerfMetrics = false,
        std::optional<executor::GuidedDecodingParams> guidedDecodingParams = std::nullopt,
        std::optional<SizeType32> languageAdapterUid = std::nullopt,
        std::optional<MillisecondsType> allottedTimeMs = std::nullopt,
        std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt,
        std::optional<CacheSaltIDType> cacheSaltID = std::nullopt, std::optional<TimePoint> arrivalTime = std::nullopt)
        : Base(requestId, //
            maxNewTokens, //
            std::make_shared<std::vector<TokenIdType>>(std::move(inputTokens)), //
            samplingConfig, //
            isStreaming, //
            endId, //
            padId, //
            embeddingBias, //
            badWordsList, //
            stopWordsList, //
            positionIds.has_value() ? std::make_shared<std::vector<SizeType32>>(std::move(positionIds.value())) //
                                    : std::optional<std::shared_ptr<std::vector<SizeType32>>>(std::nullopt), //
            promptEmbeddingTable, //
            promptVocabSize, //
            multimodalHashes.has_value()
                ? std::make_optional(
                    std::make_shared<std::vector<std::vector<SizeType32>>>(std::move(multimodalHashes.value()))) //
                : std::optional<std::shared_ptr<std::vector<std::vector<SizeType32>>>>(std::nullopt), //
            multimodalPositions.has_value()
                ? std::make_shared<std::vector<SizeType32>>(std::move(multimodalPositions.value())) //
                : std::optional<std::shared_ptr<std::vector<SizeType32>>>(std::nullopt), //
            multimodalLengths.has_value()
                ? std::make_shared<std::vector<SizeType32>>(std::move(multimodalLengths.value())) //
                : std::optional<std::shared_ptr<std::vector<SizeType32>>>(std::nullopt), //
            multimodalEmbedding, //
            mropeRotaryCosSin, //
            mropePositionDeltas, //
            loraTaskId, //
            loraWeights, //
            loraConfig, //
            lookaheadConfig, //
            kvCacheRetentionConfig, //
            returnLogProbs, //
            returnContextLogits, //
            returnGenerationLogits, //
            draftTokens.has_value() ? std::make_shared<VecTokens>(std::move(draftTokens.value())) //
                                    : std::make_shared<VecTokens>(), //
            draftLogits, //
            excludeInputFromOutput, //
            logitsPostProcessor, //
            applyLogitsPostProcessorBatched, //
            encoderInputTokens ? std::make_optional(std::make_shared<VecTokens>(std::move(*encoderInputTokens))) //
                               : std::optional<std::shared_ptr<VecTokens>>(std::nullopt), //
            returnEncoderOutput, //
            clientId, //
            priority, //
            encoderInputFeatures, //
            encoderOutputLength, //
            crossAttentionMask, //
            llmRequestType, //
            inputTokenExtraIds //
                ? std::make_optional(std::make_shared<VecTokenExtraIds>(std::move(*inputTokenExtraIds))) //
                : std::optional<std::shared_ptr<VecTokenExtraIds>>(std::nullopt), //
            numReturnSequences, //
            eagleConfig, //
            skipCrossAttnBlocks, //
            returnPerfMetrics, //
            guidedDecodingParams, //
            languageAdapterUid, //
            allottedTimeMs, //
            contextPhaseParams, //
            cacheSaltID, //
            arrivalTime //
        )
    {
    }

    static std::optional<tb::LlmRequest::LogitsPostProcessor> callbackAdapter(
        std::optional<LlmRequest::LogitsPostProcessor> callback);

    [[nodiscard]] std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> toTrtLlm() const;
};

} // namespace tensorrt_llm::pybind::batch_manager
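The constructor above repeatedly converts pybind-friendly `std::optional<std::vector<T>>` arguments into the `std::optional<std::shared_ptr<std::vector<T>>>` form that `GenericLlmRequest` stores, moving the payload exactly once. A minimal sketch of that conversion in isolation (`to_shared` is a hypothetical helper, not part of the codebase):

#include <memory>
#include <optional>
#include <utility>
#include <vector>

template <typename T>
std::optional<std::shared_ptr<std::vector<T>>> to_shared(std::optional<std::vector<T>> v)
{
    if (!v)
    {
        return std::nullopt;
    }
    // Move the payload once; the shared_ptr then owns it for the request's lifetime.
    return std::make_shared<std::vector<T>>(std::move(*v));
}

int main()
{
    auto ids = to_shared<int>(std::vector<int>{1, 2, 3});
    auto none = to_shared<int>(std::nullopt);
    return (ids && !none) ? 0 : 1;
}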
@ -1,518 +0,0 @@

/*
 * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <pybind11/cast.h>
#include <pybind11/chrono.h>
#include <pybind11/functional.h>
#include <pybind11/operators.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <torch/extension.h>
#include <vector>

#include "tensorrt_llm/batch_manager/peftCacheManagerConfig.h"
#include "tensorrt_llm/common/quantization.h"
#include "tensorrt_llm/pybind/batch_manager/algorithms.h"
#include "tensorrt_llm/pybind/batch_manager/bindings.h"
#include "tensorrt_llm/pybind/batch_manager/buffers.h"

#include "tensorrt_llm/pybind/batch_manager/cacheTransceiver.h"
#include "tensorrt_llm/pybind/batch_manager/kvCacheConnector.h"
#include "tensorrt_llm/pybind/batch_manager/kvCacheManager.h"
#include "tensorrt_llm/pybind/batch_manager/kvCacheManagerV2Utils.h"
#include "tensorrt_llm/pybind/batch_manager/llmRequest.h"
#include "tensorrt_llm/pybind/common/tllmExceptions.h"
#include "tensorrt_llm/pybind/executor/bindings.h"
#include "tensorrt_llm/pybind/process_group/bindings.h"
#include "tensorrt_llm/pybind/runtime/bindings.h"
#include "tensorrt_llm/pybind/testing/modelSpecBinding.h"
#include "tensorrt_llm/pybind/thop/bindings.h"
#include "tensorrt_llm/pybind/userbuffers/bindings.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/cudaStream.h"
#include "tensorrt_llm/runtime/gptJsonConfig.h"
#include "tensorrt_llm/runtime/ipcNvlsMemory.h"
#include "tensorrt_llm/runtime/memoryCounters.h"
#include "tensorrt_llm/runtime/samplingConfig.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"

namespace py = pybind11;
namespace tb = tensorrt_llm::batch_manager;
namespace tpb = tensorrt_llm::pybind::batch_manager;
namespace tc = tensorrt_llm::common;
namespace tr = tensorrt_llm::runtime;
namespace tle = tensorrt_llm::executor;
using SizeType32 = tr::SizeType32;
using TokenIdType = tr::TokenIdType;
template <typename T>
using OptVec = std::optional<std::vector<T>>;

#if not defined(TRTLLM_PYBIND_MODULE)
#error "TRTLLM_PYBIND_MODULE must be defined"
#endif

namespace
{
tr::SamplingConfig makeSamplingConfig(std::vector<tr::SamplingConfig> const& configs)
{
    return tr::SamplingConfig(configs);
}
} // namespace

PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
{
    m.doc() = "TensorRT LLM Python bindings for C++ runtime";
    m.attr("binding_type") = "pybind";

    // Create MpiComm binding first since it's used in the executor bindings
    py::classh<tensorrt_llm::mpi::MpiComm>(m, "MpiComm")
        .def_static("rank",
            []()
            {
                auto& session = tensorrt_llm::mpi::MpiComm::session();
                return session.tensorrt_llm::mpi::MpiComm::getRank();
            })
        .def_static("size",
            []()
            {
                auto& session = tensorrt_llm::mpi::MpiComm::session();
                return session.tensorrt_llm::mpi::MpiComm::getSize();
            })
        .def_static("local_size",
            []()
            {
                auto& session = tensorrt_llm::mpi::MpiComm::localSession();
                return session.tensorrt_llm::mpi::MpiComm::getSize();
            })
        .def_static("local_init", []() { tensorrt_llm::mpi::MpiComm::localSession(); })
        .def_static("set_raw_mpi_session_by_fortran_handle",
            [](int64_t fortran_handle) { tensorrt_llm::mpi::MpiComm::setRawSessionByFortran(fortran_handle); })
        .def_static("split",
            [](size_t color, size_t rank)
            {
                auto& world = tensorrt_llm::mpi::MpiComm::world();
                tensorrt_llm::mpi::MpiComm::setSession(world.split(color, rank));
            });

    py::classh<tr::CudaStream>(m, "CudaStream")
        .def(py::init(
                 [](py::object py_stream)
                 {
                     cudaStream_t stream = reinterpret_cast<cudaStream_t>(py_stream.cast<uintptr_t>());
                     return tr::CudaStream{stream};
                 }),
            py::arg("stream_ptr"))
        .def("get_device", &tr::CudaStream::getDevice);

    // Create submodule for executor bindings.
    auto mExecutor = m.def_submodule("executor", "Executor bindings");
    auto mInternal = m.def_submodule("internal", "Internal submodule of TRTLLM runtime");
    auto mInternalProcessGroup = mInternal.def_submodule("process_group", "PyTorch ProcessGroup internal bindings");
    auto mInternalRuntime = mInternal.def_submodule("runtime", "Runtime internal bindings");
    auto mInternalTesting = mInternal.def_submodule("testing", "Testing internal bindings");
    auto mInternalBatchManager = mInternal.def_submodule("batch_manager", "Batch manager internal bindings");
    auto mInternalBatchManagerKvCacheV2Utils
        = mInternalBatchManager.def_submodule("kv_cache_manager_v2_utils", "KV Cache Manager V2 Utils bindings");
    auto mInternalThop = mInternal.def_submodule("thop", "Torch op internal bindings");
    auto mExceptions = m.def_submodule("exceptions", "Exceptions internal bindings");

    tensorrt_llm::pybind::executor::initBindings(mExecutor);
    tensorrt_llm::pybind::runtime::initBindingsEarly(mInternalRuntime);
    tensorrt_llm::pybind::common::initExceptionsBindings(mExceptions);
    tensorrt_llm::pybind::thop::initBindings(mInternalThop);

    auto buildInfo = m.def_submodule("BuildInfo");
    buildInfo.attr("ENABLE_MULTI_DEVICE") = py::int_(ENABLE_MULTI_DEVICE);

    py::class_<tb::PeftCacheManagerConfig>(m, "PeftCacheManagerConfig")
        .def(py::init<SizeType32, SizeType32, SizeType32, SizeType32, SizeType32, SizeType32, SizeType32, SizeType32,
                 SizeType32, std::optional<float>, std::optional<size_t>, std::optional<std::string>>(),
            py::arg("num_host_module_layer") = 0, py::arg("num_device_module_layer") = 0,
            py::arg("optimal_adapter_size") = 8, py::arg("max_adapter_size") = 64, py::arg("num_put_workers") = 1,
            py::arg("num_ensure_workers") = 1, py::arg("num_copy_streams") = 1,
            py::arg("max_pages_per_block_host") = 24, py::arg("max_pages_per_block_device") = 8,
            py::arg("device_cache_percent") = std::nullopt, py::arg("host_cache_size") = std::nullopt,
            py::arg("lora_prefetch_dir") = std::nullopt)
        .def_readwrite("num_host_module_layer", &tb::PeftCacheManagerConfig::numHostModuleLayer)
        .def_readwrite("num_device_module_layer", &tb::PeftCacheManagerConfig::numDeviceModuleLayer)
        .def_readwrite("optimal_adapter_size", &tb::PeftCacheManagerConfig::optimalAdapterSize)
        .def_readwrite("max_adapter_size", &tb::PeftCacheManagerConfig::maxAdapterSize)
        .def_readwrite("num_put_workers", &tb::PeftCacheManagerConfig::numPutWorkers)
        .def_readwrite("num_ensure_workers", &tb::PeftCacheManagerConfig::numEnsureWorkers)
        .def_readwrite("num_copy_streams", &tb::PeftCacheManagerConfig::numCopyStreams)
        .def_readwrite("max_pages_per_block_host", &tb::PeftCacheManagerConfig::maxPagesPerBlockHost)
        .def_readwrite("max_pages_per_block_device", &tb::PeftCacheManagerConfig::maxPagesPerBlockDevice)
        .def_readwrite("device_cache_percent", &tb::PeftCacheManagerConfig::deviceCachePercent)
        .def_readwrite("host_cache_size", &tb::PeftCacheManagerConfig::hostCacheSize)
        .def_readwrite("lora_prefetch_dir", &tb::PeftCacheManagerConfig::loraPrefetchDir);

    py::enum_<nvinfer1::DataType>(m, "DataType")
        .value("FLOAT", nvinfer1::DataType::kFLOAT)
        .value("HALF", nvinfer1::DataType::kHALF)
        .value("INT8", nvinfer1::DataType::kINT8)
        .value("INT32", nvinfer1::DataType::kINT32)
        .value("BOOL", nvinfer1::DataType::kBOOL)
        .value("UINT8", nvinfer1::DataType::kUINT8)
        .value("FP8", nvinfer1::DataType::kFP8)
        .value("BF16", nvinfer1::DataType::kBF16)
        .value("INT64", nvinfer1::DataType::kINT64)
        .value("NVFP4", nvinfer1::DataType::kFP4)
        .export_values();

    py::enum_<tr::ModelConfig::ModelVariant>(m, "GptModelVariant")
        .value("GPT", tr::ModelConfig::ModelVariant::kGpt)
        .value("GLM", tr::ModelConfig::ModelVariant::kGlm)
        .value("CHATGLM", tr::ModelConfig::ModelVariant::kChatGlm)
        .value("MAMBA", tr::ModelConfig::ModelVariant::kMamba)
        .value("RECURRENTGEMMA", tr::ModelConfig::ModelVariant::kRecurrentGemma);

    py::enum_<tr::ModelConfig::KVCacheType>(m, "KVCacheType")
        .value("CONTINUOUS", tr::ModelConfig::KVCacheType::kCONTINUOUS)
        .value("PAGED", tr::ModelConfig::KVCacheType::kPAGED)
        .value("DISABLED", tr::ModelConfig::KVCacheType::kDISABLED)
        .def("from_string", &tr::ModelConfig::KVCacheTypeFromString);

    py::enum_<tr::ModelConfig::LayerType>(m, "LayerType")
        .value("ATTENTION", tr::ModelConfig::LayerType::kATTENTION)
        .value("RECURRENT", tr::ModelConfig::LayerType::kRECURRENT);

    py::enum_<tr::LoraModule::ModuleType>(m, "LoraModuleType")
        .value("INVALID", tr::LoraModule::ModuleType::kINVALID)
        .value("ATTN_QKV", tr::LoraModule::ModuleType::kATTN_QKV)
        .value("ATTN_Q", tr::LoraModule::ModuleType::kATTN_Q)
        .value("ATTN_K", tr::LoraModule::ModuleType::kATTN_K)
        .value("ATTN_V", tr::LoraModule::ModuleType::kATTN_V)
        .value("ATTN_DENSE", tr::LoraModule::ModuleType::kATTN_DENSE)
        .value("MLP_H_TO_4H", tr::LoraModule::ModuleType::kMLP_H_TO_4H)
        .value("MLP_4H_TO_H", tr::LoraModule::ModuleType::kMLP_4H_TO_H)
        .value("MLP_GATE", tr::LoraModule::ModuleType::kMLP_GATE)
        .value("CROSS_ATTN_QKV", tr::LoraModule::ModuleType::kCROSS_ATTN_QKV)
        .value("CROSS_ATTN_Q", tr::LoraModule::ModuleType::kCROSS_ATTN_Q)
        .value("CROSS_ATTN_K", tr::LoraModule::ModuleType::kCROSS_ATTN_K)
        .value("CROSS_ATTN_V", tr::LoraModule::ModuleType::kCROSS_ATTN_V)
        .value("CROSS_ATTN_DENSE", tr::LoraModule::ModuleType::kCROSS_ATTN_DENSE)
        .value("MOE_H_TO_4H", tr::LoraModule::ModuleType::kMOE_H_TO_4H)
        .value("MOE_4H_TO_H", tr::LoraModule::ModuleType::kMOE_4H_TO_H)
        .value("MOE_GATE", tr::LoraModule::ModuleType::kMOE_GATE)
        .value("MOE_ROUTER", tr::LoraModule::ModuleType::kMOE_ROUTER)
        .value("MLP_ROUTER", tr::LoraModule::ModuleType::kMLP_ROUTER)
        .value("MLP_GATE_UP", tr::LoraModule::ModuleType::kMLP_GATE_UP);

    py::class_<tr::LoraModule>(m, "LoraModule")
        .def(py::init<tr::LoraModule::ModuleType, SizeType32, SizeType32, bool, bool, SizeType32, SizeType32>(),
            py::arg("module_type"), py::arg("in_dim"), py::arg("out_dim"), py::arg("in_dim_first"),
            py::arg("out_dim_first"), py::arg("in_tp_split_dim"), py::arg("out_tp_split_dim"))
        .def_property_readonly("module_type", &tr::LoraModule::name)
        .def_property_readonly("in_dim", &tr::LoraModule::inDim)
        .def_property_readonly("out_dim", &tr::LoraModule::outDim)
        .def_property_readonly("in_dim_first", &tr::LoraModule::inDimFirst)
        .def_property_readonly("out_dim_first", &tr::LoraModule::outDimFirst)
        .def_property_readonly("in_tp_split_dim", &tr::LoraModule::inTpSplitDim)
        .def_property_readonly("out_tp_split_dim", &tr::LoraModule::outTpSplitDim)
        .def_static("create_lora_modules", &tr::LoraModule::createLoraModules, py::arg("lora_module_names"),
            py::arg("hidden_size"), py::arg("mlp_hidden_size"), py::arg("num_attention_heads"),
            py::arg("num_kv_attention_heads"), py::arg("attention_head_size"), py::arg("tp_size") = 1,
            py::arg("num_experts") = 0);

    py::class_<tc::QuantMode>(m, "QuantMode")
        .def_static("none", &tc::QuantMode::none)
        .def_static("int4_weights", &tc::QuantMode::int4Weights)
        .def_static("int8_weights", &tc::QuantMode::int8Weights)
        .def_static("activations", &tc::QuantMode::activations)
        .def_static("per_channel_scaling", &tc::QuantMode::perChannelScaling)
        .def_static("per_token_scaling", &tc::QuantMode::perTokenScaling)
        .def_static("per_group_scaling", &tc::QuantMode::perGroupScaling)
        .def_static("int8_kv_cache", &tc::QuantMode::int8KvCache)
        .def_static("fp8_kv_cache", &tc::QuantMode::fp8KvCache)
        .def_static("fp8_qdq", &tc::QuantMode::fp8Qdq)
        .def_property_readonly("value", &tc::QuantMode::value)
        .def("is_set", &tc::QuantMode::isSet, py::arg("mode"))
        .def_property_readonly("has_int4_weights", &tc::QuantMode::hasInt4Weights)
        .def_property_readonly("has_int8_weights", &tc::QuantMode::hasInt8Weights)
        .def_property_readonly("has_activations", &tc::QuantMode::hasActivations)
        .def_property_readonly("has_per_channel_scaling", &tc::QuantMode::hasPerChannelScaling)
        .def_property_readonly("has_per_token_scaling", &tc::QuantMode::hasPerTokenScaling)
        .def_property_readonly("has_per_group_scaling", &tc::QuantMode::hasPerGroupScaling)
        .def_property_readonly("has_static_activation_scaling", &tc::QuantMode::hasStaticActivationScaling)
        .def_property_readonly("has_int8_kv_cache", &tc::QuantMode::hasInt8KvCache)
        .def_property_readonly("has_fp4_kv_cache", &tc::QuantMode::hasFp4KvCache)
        .def_property_readonly("has_fp8_kv_cache", &tc::QuantMode::hasFp8KvCache)
        .def_property_readonly("has_fp8_qdq", &tc::QuantMode::hasFp8Qdq)
        .def_property_readonly("has_nvfp4", &tc::QuantMode::hasNvfp4)
        .def_property_readonly("has_w4a8_mxfp4_fp8", &tc::QuantMode::hasW4a8Mxfp4Fp8)
        .def_property_readonly("has_w4a8_mxfp4_mxfp8", &tc::QuantMode::hasW4a8Mxfp4Mxfp8)
        .def_property_readonly("has_w4a16_mxfp4", &tc::QuantMode::hasW4a16Mxfp4)
        .def_property_readonly("has_kv_cache_quant", &tc::QuantMode::hasKvCacheQuant)
        .def_static("from_description", &tc::QuantMode::fromDescription, py::arg("quantize_weights"),
            py::arg("quantize_activations"), py::arg("per_token"), py::arg("per_channel"), py::arg("per_group"),
            py::arg("use_int4_weights"), py::arg("use_int8_kv_cache"), py::arg("use_fp8_kv_kache"),
            py::arg("use_fp8_qdq"), py::arg("use_fp8_rowwise"), py::arg("use_w4a8_qserve"), py::arg("use_nvfp4"),
            py::arg("use_fp8_block_scales"), py::arg("use_w4a8_mxfp4_fp8"), py::arg("use_w4a8_mxfp4_mxfp8"),
            py::arg("use_w4a16_mxfp4"))
        .def_static("use_smooth_quant", &tc::QuantMode::useSmoothQuant, py::arg("per_token") = false,
            py::arg("per_channel") = false)
        .def_static("use_weight_only", &tc::QuantMode::useWeightOnly, py::arg("use_int4_weights") = false,
            py::arg("per_group") = false)
        .def_static("from_quant_algo", &tc::QuantMode::fromQuantAlgo, py::arg("quant_algo") = py::none(),
            py::arg("kv_cache_quant_algo") = py::none())
        .def(py::self + py::self)
        .def(py::self += py::self)
        .def(py::self - py::self)
        .def(py::self -= py::self)
        .def(py::self == py::self)
        .def(py::self != py::self);

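The QuantMode binding above maps C++ operators onto Python dunder methods through pybind11's `py::self` expression templates from <pybind11/operators.h>. A minimal, self-contained sketch of the same technique; the `Flags` type is illustrative, not from the codebase:

#include <pybind11/operators.h>
#include <pybind11/pybind11.h>

namespace py = pybind11;

// Illustrative flag-set value type, standing in for something like QuantMode.
struct Flags
{
    unsigned bits = 0;
    Flags operator+(Flags other) const { return Flags{bits | other.bits}; }
    bool operator==(Flags other) const { return bits == other.bits; }
};

PYBIND11_MODULE(flags_demo, m)
{
    py::class_<Flags>(m, "Flags")
        .def(py::init<>())
        .def(py::self + py::self)   // exposed as Flags.__add__
        .def(py::self == py::self); // exposed as Flags.__eq__
}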
    py::class_<tr::ModelConfig>(m, "ModelConfig")
        .def(py::init<SizeType32, SizeType32, SizeType32, SizeType32, SizeType32, SizeType32, nvinfer1::DataType>(),
            py::arg("vocab_size"), py::arg("num_layers"), py::arg("num_attention_layers"), py::arg("num_rnn_layers"),
            py::arg("num_heads"), py::arg("hidden_size"), py::arg("data_type"))
        .def_property_readonly("vocab_size", &tr::ModelConfig::getVocabSize)
        .def("vocab_size_padded", &tr::ModelConfig::getVocabSizePadded, py::arg("world_size"))
        .def("num_layers", &tr::ModelConfig::getNbLayers, py::arg("pipeline_parallelism") = 1,
            py::arg("pipeline_parallelism_rank") = 0)
        .def("num_attention_layers", &tr::ModelConfig::getNbAttentionLayers, py::arg("pipeline_parallelism") = 1,
            py::arg("pipeline_parallelism_rank") = 0)
        .def("num_rnn_layers", &tr::ModelConfig::getNbRnnLayers, py::arg("pipeline_parallelism") = 1,
            py::arg("pipeline_parallelism_rank") = 0)
        .def("num_kv_heads", &tr::ModelConfig::getNbKvHeads, py::arg("layer_idx"))
        .def("set_num_kv_heads", &tr::ModelConfig::setNbKvHeads, py::arg("num_kv_heads"))
        .def_property_readonly("num_heads", &tr::ModelConfig::getNbHeads)
        .def_property_readonly("hidden_size", &tr::ModelConfig::getHiddenSize)
        .def_property_readonly("size_per_head", &tr::ModelConfig::getSizePerHead)
        .def_property_readonly("data_type", &tr::ModelConfig::getDataType)
        .def_property_readonly("speculative_decoding_mode", &tr::ModelConfig::getSpeculativeDecodingMode)
        .def_property("head_size", &tr::ModelConfig::getSizePerHead, &tr::ModelConfig::setSizePerHead)
        .def_property(
            "num_kv_heads_per_layer", &tr::ModelConfig::getNumKvHeadsPerLayer, &tr::ModelConfig::setNumKvHeadsPerLayer)
        .def_property("use_gpt_attention_plugin",
            py::overload_cast<>(&tr::ModelConfig::useGptAttentionPlugin, py::const_),
            py::overload_cast<bool>(&tr::ModelConfig::useGptAttentionPlugin))
        .def_property("use_packed_input", py::overload_cast<>(&tr::ModelConfig::usePackedInput, py::const_),
            py::overload_cast<bool>(&tr::ModelConfig::usePackedInput))
        .def_property("kv_cache_type", py::overload_cast<>(&tr::ModelConfig::getKVCacheType, py::const_),
            py::overload_cast<tr::ModelConfig::KVCacheType>(&tr::ModelConfig::setKVCacheType))
        .def_property("tokens_per_block", &tr::ModelConfig::getTokensPerBlock, &tr::ModelConfig::setTokensPerBlock)
        .def_property("quant_mode", &tr::ModelConfig::getQuantMode, &tr::ModelConfig::setQuantMode)
        .def_property_readonly("supports_inflight_batching", &tr::ModelConfig::supportsInflightBatching)
        .def_property("max_batch_size", &tr::ModelConfig::getMaxBatchSize, &tr::ModelConfig::setMaxBatchSize)
        .def_property("max_beam_width", &tr::ModelConfig::getMaxBeamWidth, &tr::ModelConfig::setMaxBeamWidth)
        .def_property("max_input_len", &tr::ModelConfig::getMaxInputLen, &tr::ModelConfig::setMaxInputLen)
        .def_property("max_seq_len", &tr::ModelConfig::getMaxSequenceLen, &tr::ModelConfig::setMaxSequenceLen)
        .def_property("max_num_tokens", &tr::ModelConfig::getMaxNumTokens, &tr::ModelConfig::setMaxNumTokens)
        .def_property("max_prompt_embedding_table_size", &tr::ModelConfig::getMaxPromptEmbeddingTableSize,
            &tr::ModelConfig::setMaxPromptEmbeddingTableSize)
        .def_property_readonly("use_prompt_tuning", &tr::ModelConfig::usePromptTuning)
        .def_property_readonly("use_mrope", &tr::ModelConfig::useMrope)
        .def_property("use_lora_plugin", py::overload_cast<>(&tr::ModelConfig::useLoraPlugin, py::const_),
            py::overload_cast<bool>(&tr::ModelConfig::useLoraPlugin))
        .def_property("layer_types", &tr::ModelConfig::getLayerTypes, &tr::ModelConfig::setLayerTypes)
        .def_property("compute_context_logits",
            py::overload_cast<>(&tr::ModelConfig::computeContextLogits, py::const_),
            py::overload_cast<bool>(&tr::ModelConfig::computeContextLogits))
        .def_property("compute_generation_logits",
            py::overload_cast<>(&tr::ModelConfig::computeGenerationLogits, py::const_),
            py::overload_cast<bool>(&tr::ModelConfig::computeGenerationLogits))
        .def_property("model_variant", &tr::ModelConfig::getModelVariant, &tr::ModelConfig::setModelVariant)
        .def_property(
            "use_cross_attention", &tr::ModelConfig::useCrossAttention, &tr::ModelConfig::setUseCrossAttention)
        .def_property("lora_modules", &tr::ModelConfig::getLoraModules, &tr::ModelConfig::setLoraModules)
        .def_property("max_lora_rank", &tr::ModelConfig::getMaxLoraRank, &tr::ModelConfig::setMaxLoraRank)
        .def_property("mlp_hidden_size", &tr::ModelConfig::getMlpHiddenSize, &tr::ModelConfig::setMlpHiddenSize)
        .def_property("size_per_head", &tr::ModelConfig::getSizePerHead, &tr::ModelConfig::setSizePerHead);

    py::class_<tr::WorldConfig>(m, "WorldConfig")
        .def(py::init<SizeType32, SizeType32, SizeType32, SizeType32, SizeType32,
                 std::optional<std::vector<SizeType32>> const&, bool>(),
            py::arg("tensor_parallelism") = 1, py::arg("pipeline_parallelism") = 1, py::arg("context_parallelism") = 1,
            py::arg("rank") = 0, py::arg("gpus_per_node") = tr::WorldConfig::kDefaultGpusPerNode,
            py::arg("device_ids") = py::none(), py::arg("enable_attention_dp") = false)
        .def_property_readonly("size", &tr::WorldConfig::getSize)
        .def_property_readonly("tensor_parallelism", &tr::WorldConfig::getTensorParallelism)
        .def_property_readonly("pipeline_parallelism", &tr::WorldConfig::getPipelineParallelism)
        .def_property_readonly("context_parallelism", &tr::WorldConfig::getContextParallelism)
        .def_property_readonly("is_tensor_parallel", &tr::WorldConfig::isTensorParallel)
        .def_property_readonly("is_pipeline_parallel", &tr::WorldConfig::isPipelineParallel)
        .def_property_readonly("is_context_parallel", &tr::WorldConfig::isContextParallel)
        .def_property_readonly("rank", &tr::WorldConfig::getRank)
        .def_property_readonly("local_rank", &tr::WorldConfig::getLocalRank)
        .def_property_readonly("node_rank", &tr::WorldConfig::getNodeRank)
        .def_property_readonly("gpus_per_node", &tr::WorldConfig::getGpusPerNode)
        .def_property_readonly("gpus_per_group", &tr::WorldConfig::getGpusPerGroup)
        .def_property_readonly("device", &tr::WorldConfig::getDevice)
        .def_property_readonly("pipeline_parallel_rank", &tr::WorldConfig::getPipelineParallelRank)
        .def_property_readonly("tensor_parallel_rank", &tr::WorldConfig::getTensorParallelRank)
        .def_property_readonly("context_parallel_rank", &tr::WorldConfig::getContextParallelRank)
        .def_property_readonly("enable_attention_dp", &tr::WorldConfig::enableAttentionDP)
        .def_static("mpi",
            py::overload_cast<SizeType32, std::optional<SizeType32>, std::optional<SizeType32>,
                std::optional<SizeType32>, std::optional<std::vector<SizeType32>> const&, bool>(&tr::WorldConfig::mpi),
            py::arg("gpus_per_node") = tr::WorldConfig::kDefaultGpusPerNode, py::arg("tensor_parallelism") = py::none(),
            py::arg("pipeline_parallelism") = py::none(), py::arg("context_parallelism") = py::none(),
            py::arg("device_ids") = py::none(), py::arg("enable_attention_dp") = false);

    auto SamplingConfigGetState = [](tr::SamplingConfig const& config) -> py::tuple
    {
        return py::make_tuple(config.beamWidth, config.temperature, config.minLength, config.repetitionPenalty,
            config.presencePenalty, config.frequencyPenalty, config.promptIgnoreLength, config.topK, config.topP,
            config.randomSeed, config.topPDecay, config.topPMin, config.topPResetIds, config.beamSearchDiversityRate,
            config.lengthPenalty, config.earlyStopping, config.noRepeatNgramSize, config.numReturnSequences,
            config.minP, config.beamWidthArray);
    };
    auto SamplingConfigSetState = [](py::tuple t) -> tr::SamplingConfig
    {
        if (t.size() != 20)
        {
            throw std::runtime_error("Invalid SamplingConfig state!");
        }

        tr::SamplingConfig config;
        config.beamWidth = t[0].cast<SizeType32>();
        config.temperature = t[1].cast<OptVec<float>>();
        config.minLength = t[2].cast<OptVec<SizeType32>>();
        config.repetitionPenalty = t[3].cast<OptVec<float>>();
        config.presencePenalty = t[4].cast<OptVec<float>>();
        config.frequencyPenalty = t[5].cast<OptVec<float>>();
        config.promptIgnoreLength = t[6].cast<OptVec<SizeType32>>();
        config.topK = t[7].cast<OptVec<SizeType32>>();
        config.topP = t[8].cast<OptVec<float>>();
        config.randomSeed = t[9].cast<OptVec<uint64_t>>();
        config.topPDecay = t[10].cast<OptVec<float>>();
        config.topPMin = t[11].cast<OptVec<float>>();
        config.topPResetIds = t[12].cast<OptVec<TokenIdType>>();
        config.beamSearchDiversityRate = t[13].cast<OptVec<float>>();
        config.lengthPenalty = t[14].cast<OptVec<float>>();
        config.earlyStopping = t[15].cast<OptVec<SizeType32>>();
        config.noRepeatNgramSize = t[16].cast<OptVec<SizeType32>>();
        config.numReturnSequences = t[17].cast<SizeType32>();
        config.minP = t[18].cast<OptVec<float>>();
        config.beamWidthArray = t[19].cast<OptVec<std::vector<SizeType32>>>();

        return config;
    };

    py::classh<tr::SamplingConfig>(m, "SamplingConfig")
        .def(py::init<SizeType32>(), py::arg("beam_width") = 1)
        .def(py::init<tle::SamplingConfig, std::optional<tle::ExternalDraftTokensConfig>>(),
            py::arg("executor_sample_config"), py::arg("external_draft_tokens_config") = std::nullopt)
        .def_readwrite("beam_width", &tr::SamplingConfig::beamWidth)
        .def_readwrite("temperature", &tr::SamplingConfig::temperature)
        .def_readwrite("min_length", &tr::SamplingConfig::minLength)
        .def_readwrite("repetition_penalty", &tr::SamplingConfig::repetitionPenalty)
        .def_readwrite("presence_penalty", &tr::SamplingConfig::presencePenalty)
        .def_readwrite("frequency_penalty", &tr::SamplingConfig::frequencyPenalty)
        .def_readwrite("prompt_ignore_length", &tr::SamplingConfig::promptIgnoreLength)
        .def_readwrite("top_k", &tr::SamplingConfig::topK)
        .def_readwrite("top_p", &tr::SamplingConfig::topP)
        .def_readwrite("random_seed", &tr::SamplingConfig::randomSeed)
        .def_readwrite("top_p_decay", &tr::SamplingConfig::topPDecay)
        .def_readwrite("top_p_min", &tr::SamplingConfig::topPMin)
        .def_readwrite("top_p_reset_ids", &tr::SamplingConfig::topPResetIds)
        .def_readwrite("beam_search_diversity_rate", &tr::SamplingConfig::beamSearchDiversityRate)
        .def_readwrite("length_penalty", &tr::SamplingConfig::lengthPenalty)
        .def_readwrite("early_stopping", &tr::SamplingConfig::earlyStopping)
        .def_readwrite("no_repeat_ngram_size", &tr::SamplingConfig::noRepeatNgramSize)
        .def_readwrite("num_return_sequences", &tr::SamplingConfig::numReturnSequences)
        .def_readwrite("min_p", &tr::SamplingConfig::minP)
        .def_readwrite("beam_width_array", &tr::SamplingConfig::beamWidthArray)
        .def_readwrite("normalize_log_probs", &tr::SamplingConfig::normalizeLogProbs)
        .def(py::pickle(SamplingConfigGetState, SamplingConfigSetState))
        .def("__eq__", &tr::SamplingConfig::operator==);

    m.def("make_sampling_config", &makeSamplingConfig, py::arg("configs"));

    py::class_<tr::GptJsonConfig>(m, "GptJsonConfig")
        .def(py::init<std::string, std::string, std::string, SizeType32, SizeType32, SizeType32, SizeType32,
                 tr::ModelConfig, std::optional<tr::RuntimeDefaults>>(),
            py::arg("name"), py::arg("version"), py::arg("precision"), py::arg("tensor_parallelism"),
            py::arg("pipeline_parallelism"), py::arg("context_parallelism"), py::arg("gpus_per_node"),
            py::arg("model_config"), py::arg("runtime_defaults") = py::none())
        .def_static("parse", py::overload_cast<std::string const&>(&tr::GptJsonConfig::parse), py::arg("json"))
        .def_static(
            "parse_file", py::overload_cast<std::filesystem::path const&>(&tr::GptJsonConfig::parse), py::arg("path"))
        .def_property_readonly("model_config", &tr::GptJsonConfig::getModelConfig)
        .def_property_readonly("name", &tr::GptJsonConfig::getName)
        .def_property_readonly("version", &tr::GptJsonConfig::getVersion)
        .def_property_readonly("precision", &tr::GptJsonConfig::getPrecision)
        .def_property_readonly("tensor_parallelism", &tr::GptJsonConfig::getTensorParallelism)
        .def_property_readonly("pipeline_parallelism", &tr::GptJsonConfig::getPipelineParallelism)
        .def_property_readonly("context_parallelism", &tr::GptJsonConfig::getContextParallelism)
        .def_property_readonly("gpus_per_node", &tr::GptJsonConfig::getGpusPerNode)
        .def_property_readonly("world_size", &tr::GptJsonConfig::getWorldSize)
        .def_property_readonly("runtime_defaults", &tr::GptJsonConfig::getRuntimeDefaults)
        .def("engine_filename",
            py::overload_cast<tr::WorldConfig const&, std::string const&>(
                &tr::GptJsonConfig::engineFilename, py::const_),
            py::arg("world_config"), py::arg("model"))
        .def("engine_filename",
            py::overload_cast<tr::WorldConfig const&>(&tr::GptJsonConfig::engineFilename, py::const_),
            py::arg("world_config"));

    py::enum_<tb::LlmRequestState>(m, "LlmRequestState")
        .value("UNKNOWN", tb::LlmRequestState::kUNKNOWN)
        .value("ENCODER_INIT", tb::LlmRequestState::kENCODER_INIT)
        .value("CONTEXT_INIT", tb::LlmRequestState::kCONTEXT_INIT)
        .value("GENERATION_IN_PROGRESS", tb::LlmRequestState::kGENERATION_IN_PROGRESS)
        .value("GENERATION_TO_COMPLETE", tb::LlmRequestState::kGENERATION_TO_COMPLETE)
        .value("GENERATION_COMPLETE", tb::LlmRequestState::kGENERATION_COMPLETE)
        .value("DISAGG_GENERATION_INIT", tb::LlmRequestState::kDISAGG_GENERATION_INIT)
        .value("DISAGG_CONTEXT_TRANS_IN_PROGRESS", tb::LlmRequestState::kDISAGG_CONTEXT_TRANS_IN_PROGRESS)
        .value("DISAGG_CONTEXT_COMPLETE", tb::LlmRequestState::kDISAGG_CONTEXT_COMPLETE)
        .value("DISAGG_GENERATION_TRANS_IN_PROGRESS", tb::LlmRequestState::kDISAGG_GENERATION_TRANS_IN_PROGRESS)
        .value("DISAGG_TRANS_ERROR", tb::LlmRequestState::kDISAGG_TRANS_ERROR)
        .value("DISAGG_GENERATION_TRANS_COMPLETE", tb::LlmRequestState::kDISAGG_GENERATION_TRANS_COMPLETE)
        .value("DISAGG_CONTEXT_INIT_AND_TRANS", tb::LlmRequestState::kDISAGG_CONTEXT_INIT_AND_TRANS)
        .value("DISAGG_CONTEXT_WAIT_SCHEDULER", tb::LlmRequestState::kDISAGG_CONTEXT_WAIT_SCHEDULER)
        .value("DISAGG_GENERATION_WAIT_TOKENS", tb::LlmRequestState::kDISAGG_GENERATION_WAIT_TOKENS);

    py::class_<tr::MemoryCounters>(m, "MemoryCounters")
        .def_static("instance", &tr::MemoryCounters::getInstance, py::return_value_policy::reference)
        .def_property_readonly("gpu", &tr::MemoryCounters::getGpu)
        .def_property_readonly("cpu", &tr::MemoryCounters::getCpu)
        .def_property_readonly("pinned", &tr::MemoryCounters::getPinned)
        .def_property_readonly("uvm", &tr::MemoryCounters::getUVM);

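MemoryCounters is a process-wide singleton, so its `instance` accessor is bound with `py::return_value_policy::reference`, which tells pybind11 that the C++ side owns the object and Python must not delete it. A minimal sketch of that pattern, assuming a hypothetical `Counter` singleton:

#include <pybind11/pybind11.h>

namespace py = pybind11;

// Hypothetical process-wide singleton, standing in for MemoryCounters.
class Counter
{
public:
    static Counter& getInstance()
    {
        static Counter instance;
        return instance;
    }

    int value() const { return mValue; }
    void bump() { ++mValue; }

private:
    Counter() = default;
    int mValue = 0;
};

PYBIND11_MODULE(counter_demo, m)
{
    // return_value_policy::reference keeps pybind11 from taking ownership of
    // the static instance when the Python wrapper is garbage-collected.
    py::class_<Counter>(m, "Counter")
        .def_static("instance", &Counter::getInstance, py::return_value_policy::reference)
        .def_property_readonly("value", &Counter::value)
        .def("bump", &Counter::bump);
}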
    tensorrt_llm::pybind::process_group::initBindings(mInternalProcessGroup);
    tpb::Buffers::initBindings(mInternalBatchManager);
    tensorrt_llm::pybind::runtime::initBindings(mInternalRuntime);
    tensorrt_llm::pybind::testing::initBindings(mInternalTesting);
    tpb::initBindings(mInternalBatchManager);
    tb::kv_cache_manager::KVCacheManagerConnectorBindings::initBindings(mInternalBatchManager);
    tb::kv_cache_manager::KVCacheManagerBindings::initBindings(mInternalBatchManager);
    tb::BasePeftCacheManagerBindings::initBindings(mInternalBatchManager);
    tb::CacheTransceiverBindings::initBindings(mInternalBatchManager);
    tb::kv_cache_manager_v2::KVCacheManagerV2UtilsBindings::initBindings(mInternalBatchManagerKvCacheV2Utils);

    auto mInternalAlgorithms = mInternal.def_submodule("algorithms", "Algorithms internal bindings");
    tpb::algorithms::initBindings(mInternalAlgorithms);

    auto mUserbuffers = mInternal.def_submodule("userbuffers", "User buffers internal bindings");
    tensorrt_llm::kernels::userbuffers::UserBufferBindings::initBindings(mUserbuffers);

    // NVLS allocators
    py::class_<tr::IpcNvlsHandle>(m, "IpcNvlsHandle")
        .def(py::init<>())
        .def_readwrite("uc_ptr", &tr::IpcNvlsHandle::uc_ptr)
        .def_readwrite("mc_ptr", &tr::IpcNvlsHandle::mc_ptr)
        .def_readwrite("size", &tr::IpcNvlsHandle::size)
        .def("get_ipc_ptrs",
            [](tr::IpcNvlsHandle& self) { return reinterpret_cast<uintptr_t>(self.ipc_uc_ptrs.data()); });

    m.def("ipc_nvls_allocate", &tr::ipcNvlsAllocate, py::return_value_policy::reference);
    m.def("ipc_nvls_free", &tr::ipcNvlsFree);
    m.def("ipc_nvls_supported", &tr::ipcNvlsSupported);

    m.def("steady_clock_now", []() { return std::chrono::steady_clock::now(); });
}
@ -1,94 +0,0 @@

/*
 * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "tensorrt_llm/pybind/common/customCasters.h"
#include <pybind11/pybind11.h>

namespace PybindUtils
{

namespace py = pybind11;

template <typename T>
void bindList(py::module& m, std::string const& name)
{
    py::class_<T>(m, name.c_str())
        .def(py::init())
        .def("push_back", [](T& lst, const typename T::value_type& value) { lst.push_back(value); })
        .def("pop_back", [](T& lst) { lst.pop_back(); })
        .def("push_front", [](T& lst, const typename T::value_type& value) { lst.push_front(value); })
        .def("pop_front", [](T& lst) { lst.pop_front(); })
        .def("__len__", [](T const& lst) { return lst.size(); })
        .def(
            "__iter__", [](T& lst) { return py::make_iterator(lst.begin(), lst.end()); }, py::keep_alive<0, 1>())
        .def("__getitem__",
            [](T const& lst, size_t index)
            {
                if (index >= lst.size())
                    throw py::index_error();
                auto it = lst.begin();
                std::advance(it, index);
                return *it;
            })
        .def("__setitem__",
            [](T& lst, size_t index, const typename T::value_type& value)
            {
                if (index >= lst.size())
                    throw py::index_error();
                auto it = lst.begin();
                std::advance(it, index);
                *it = value;
            });
}

template <typename T>
void bindSet(py::module& m, std::string const& name)
{
    py::class_<T>(m, name.c_str())
        .def(py::init())
        .def("clear", &T::clear)
        .def("size", &T::size)
        .def("insert", [](T& s, typename T::value_type const& value) { s.insert(value); })
        .def("erase", py::overload_cast<typename T::value_type const&>(&T::erase))
        .def("__len__", [](T const& lst) { return lst.size(); })
        .def("__contains__", [](T const& s, typename T::value_type x) { return s.find(x) != s.end(); })
        .def(
            "__iter__", [](T& s) { return py::make_iterator(s.begin(), s.end()); }, py::keep_alive<0, 1>())
        .def("__eq__", [](T const& s, T const& other) { return s == other; })
        .def(py::pickle(
            [](T const& s) { // __getstate__
                /* Return a tuple that fully encodes the state of the object */
                return py::make_tuple(std::vector<typename T::value_type>(s.begin(), s.end()));
            },
            [](py::tuple t) { // __setstate__
                if (t.size() != 1)
                    throw std::runtime_error("Invalid state!");
                /* Create a new C++ instance */
                T s;
                /* Assign any additional state */
                auto state_list = t[0].cast<std::vector<typename T::value_type>>();
                for (auto& item : state_list)
                {
                    s.insert(item);
                }
                return s;
            }));
}

} // namespace PybindUtils
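A sketch of how these helpers would be instantiated for concrete containers; the `IntList`/`IntSet` names and the header path are illustrative assumptions. The PYBIND11_MAKE_OPAQUE declarations are what keep pybind11 from copy-converting the containers, so mutations made from Python are visible to C++:

#include <list>
#include <set>

#include <pybind11/pybind11.h>

#include "tensorrt_llm/pybind/common/pybindUtils.h" // assumed header location

// Must be visible before any binding or caster sees these types.
PYBIND11_MAKE_OPAQUE(std::list<int>)
PYBIND11_MAKE_OPAQUE(std::set<int>)

PYBIND11_MODULE(containers_demo, m)
{
    // Exposes push_back/pop_back/__len__/__iter__/__getitem__/__setitem__ ...
    PybindUtils::bindList<std::list<int>>(m, "IntList");
    // ... and clear/size/insert/erase/__contains__ plus pickling for the set.
    PybindUtils::bindSet<std::set<int>>(m, "IntSet");
}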
@ -1,265 +0,0 @@

/*
 * SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "tensorrt_llm/batch_manager/common.h"
#include "tensorrt_llm/batch_manager/decoderBuffers.h"
#include "tensorrt_llm/common/optionalRef.h"
#include "tensorrt_llm/runtime/cudaStream.h"
#include "tensorrt_llm/runtime/torch.h"
#include "tensorrt_llm/runtime/torchView.h"

#include <pybind11/cast.h>
#include <pybind11/detail/common.h>
#include <pybind11/detail/descr.h>
#include <pybind11/pybind11.h>
#include <pybind11/pytypes.h>
#include <pybind11/stl_bind.h>
#include <torch/extension.h>

// Pybind requires to have a central include in order for type casters to work.
// Opaque bindings add a type caster, so they have the same requirement.
// See the warning in https://pybind11.readthedocs.io/en/stable/advanced/cast/custom.html

// Opaque bindings
PYBIND11_MAKE_OPAQUE(tensorrt_llm::batch_manager::ReqIdsSet)
PYBIND11_MAKE_OPAQUE(std::vector<tensorrt_llm::batch_manager::SlotDecoderBuffers>)

// Custom casters
namespace PYBIND11_NAMESPACE
{

namespace detail
{

template <typename T>
struct type_caster<tensorrt_llm::common::OptionalRef<T>>
{
    using value_conv = make_caster<T>;

    PYBIND11_TYPE_CASTER(tensorrt_llm::common::OptionalRef<T>, value_conv::name);

    bool load(handle src, bool convert)
    {
        if (src.is_none())
        {
            // If the Python object is None, create an empty OptionalRef
            value = tensorrt_llm::common::OptionalRef<T>();
            return true;
        }

        value_conv conv;
        if (!conv.load(src, convert))
            return false;

        // Create an OptionalRef with a reference to the converted value
        value = tensorrt_llm::common::OptionalRef<T>(conv);
        return true;
    }

    static handle cast(tensorrt_llm::common::OptionalRef<T> const& src, return_value_policy policy, handle parent)
    {
        if (!src.has_value())
            return none().release();

        return value_conv::cast(*src, policy, parent);
    }
};

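With this caster in place, a bound function can accept `OptionalRef<T>` and receive an empty ref whenever Python passes None. A minimal sketch; the `describe` function is illustrative, not from the codebase:

#include <string>

#include <pybind11/pybind11.h>

#include "tensorrt_llm/pybind/common/customCasters.h"

namespace py = pybind11;

// Illustrative only: callable from Python with either an int or None.
std::string describe(tensorrt_llm::common::OptionalRef<int> maybeValue)
{
    if (!maybeValue.has_value())
    {
        return "no value";
    }
    return "value = " + std::to_string(*maybeValue);
}

PYBIND11_MODULE(optional_ref_demo, m)
{
    m.def("describe", &describe); // describe(None) -> "no value"
}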
template <typename T>
struct PathCaster
{

private:
    static PyObject* unicode_from_fs_native(std::string const& w)
    {
        return PyUnicode_DecodeFSDefaultAndSize(w.c_str(), ssize_t(w.size()));
    }

    static PyObject* unicode_from_fs_native(std::wstring const& w)
    {
        return PyUnicode_FromWideChar(w.c_str(), ssize_t(w.size()));
    }

public:
    static handle cast(T const& path, return_value_policy, handle)
    {
        if (auto py_str = unicode_from_fs_native(path.native()))
        {
            return module_::import("pathlib").attr("Path")(reinterpret_steal<object>(py_str)).release();
        }
        return nullptr;
    }

    bool load(handle handle, bool)
    {
        PyObject* native = nullptr;
        if constexpr (std::is_same_v<typename T::value_type, char>)
        {
            if (PyUnicode_FSConverter(handle.ptr(), &native) != 0)
            {
                if (auto* c_str = PyBytes_AsString(native))
                {
                    // AsString returns a pointer to the internal buffer, which
                    // must not be free'd.
                    value = c_str;
                }
            }
        }
        else if constexpr (std::is_same_v<typename T::value_type, wchar_t>)
        {
            if (PyUnicode_FSDecoder(handle.ptr(), &native) != 0)
            {
                if (auto* c_str = PyUnicode_AsWideCharString(native, nullptr))
                {
                    // AsWideCharString returns a new string that must be free'd.
                    value = c_str; // Copies the string.
                    PyMem_Free(c_str);
                }
            }
        }
        Py_XDECREF(native);
        if (PyErr_Occurred())
        {
            PyErr_Clear();
            return false;
        }
        return true;
    }

    PYBIND11_TYPE_CASTER(T, const_name("os.PathLike"));
};

template <>
struct type_caster<std::filesystem::path> : public PathCaster<std::filesystem::path>
{
};

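A sketch of the PathCaster in use: Python can pass a str, bytes, or pathlib.Path, and any returned path comes back as pathlib.Path. The `with_suffix` binding is illustrative only:

#include <filesystem>

#include <pybind11/pybind11.h>

#include "tensorrt_llm/pybind/common/customCasters.h"

namespace py = pybind11;

PYBIND11_MODULE(path_demo, m)
{
    // Illustrative only: with_suffix("model.plan") -> pathlib.Path("model.bin")
    m.def("with_suffix", [](std::filesystem::path p) { return p.replace_extension(".bin"); });
}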
template <>
class type_caster<tensorrt_llm::executor::StreamPtr>
{
public:
    PYBIND11_TYPE_CASTER(tensorrt_llm::executor::StreamPtr, _("int"));

    bool load([[maybe_unused]] handle src, bool)
    {
        auto stream_ptr = src.cast<uintptr_t>();
        value = std::make_shared<tensorrt_llm::runtime::CudaStream>(reinterpret_cast<cudaStream_t>(stream_ptr));

        return true;
    }

    static handle cast(
        tensorrt_llm::executor::StreamPtr const& src, return_value_policy /* policy */, handle /* parent */)
    {
        // Return cudaStream_t as integer.
        return PyLong_FromVoidPtr(src->get());
    }
};

template <>
struct type_caster<tensorrt_llm::executor::Tensor>
{
public:
    PYBIND11_TYPE_CASTER(tensorrt_llm::executor::Tensor, _("torch.Tensor"));

    // Convert PyObject(torch.Tensor) -> tensorrt_llm::executor::Tensor
    bool load(handle src, bool)
    {
        PyObject* obj = src.ptr();
        if (THPVariable_Check(obj))
        {
            at::Tensor const& t = THPVariable_Unpack(obj);
            value = tensorrt_llm::executor::detail::ofITensor(tensorrt_llm::runtime::TorchView::of(t));
            return true;
        }
        return false;
    }

    // Convert tensorrt_llm::executor::Tensor -> PyObject(torch.Tensor)
    static handle cast(tensorrt_llm::executor::Tensor const& src, return_value_policy /* policy */, handle /* parent */)
    {
        return THPVariable_Wrap(tensorrt_llm::runtime::Torch::tensor(tensorrt_llm::executor::detail::toITensor(src)));
    }
};

template <>
struct type_caster<tensorrt_llm::runtime::ITensor::SharedPtr>
{
public:
    PYBIND11_TYPE_CASTER(tensorrt_llm::runtime::ITensor::SharedPtr, _("torch.Tensor"));

    // Convert PyObject(torch.Tensor) -> tensorrt_llm::runtime::ITensor::SharedPtr
    bool load(handle src, bool)
    {
        PyObject* obj = src.ptr();
        if (THPVariable_Check(obj))
        {
            at::Tensor const& t = THPVariable_Unpack(obj);
            value = std::move(tensorrt_llm::runtime::TorchView::of(t));
            return true;
        }
        return false;
    }

    // Convert tensorrt_llm::runtime::ITensor::SharedPtr -> PyObject(torch.Tensor)
    static handle cast(
        tensorrt_llm::runtime::ITensor::SharedPtr const& src, return_value_policy /* policy */, handle /* parent */)
    {
        if (src == nullptr)
        {
            return none().release();
        }
        return THPVariable_Wrap(tensorrt_llm::runtime::Torch::tensor(src));
    }
};

template <>
struct type_caster<tensorrt_llm::runtime::ITensor::SharedConstPtr>
{
public:
    PYBIND11_TYPE_CASTER(tensorrt_llm::runtime::ITensor::SharedConstPtr, _("torch.Tensor"));

    // Convert PyObject(torch.Tensor) -> tensorrt_llm::runtime::ITensor::SharedConstPtr
    bool load(handle src, bool)
    {
        PyObject* obj = src.ptr();
        if (THPVariable_Check(obj))
        {
            at::Tensor const& t = THPVariable_Unpack(obj);
            value = std::move(tensorrt_llm::runtime::TorchView::of(t));
            return true;
        }
        return false;
    }

    // Convert tensorrt_llm::runtime::ITensor::SharedConstPtr -> PyObject(torch.Tensor)
    static handle cast(tensorrt_llm::runtime::ITensor::SharedConstPtr const& src, return_value_policy /* policy */,
        handle /* parent */)
    {
        if (src == nullptr)
        {
            return none().release();
        }
        return THPVariable_Wrap(tensorrt_llm::runtime::Torch::tensor(
            reinterpret_cast<tensorrt_llm::runtime::ITensor::SharedPtr const&>(src)));
    }
};

} // namespace detail
} // namespace PYBIND11_NAMESPACE
@ -1,67 +0,0 @@

/*
 * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tllmExceptions.h"

namespace py = pybind11;
namespace tc = tensorrt_llm::common;

namespace tensorrt_llm::pybind::common
{

void initExceptionsBindings(py::module_& m)
{
    // Bind the RequestErrorCode enum
    py::enum_<tc::RequestErrorCode>(m, "RequestErrorCode")
        .value("UNKNOWN_ERROR", tc::RequestErrorCode::kUNKNOWN_ERROR)
        .value("NETWORK_ERROR", tc::RequestErrorCode::kNETWORK_ERROR)
        .export_values();

    // Create the RequestSpecificException Python exception class
    static PyObject* request_specific_exc
        = PyErr_NewException("tensorrt_llm.RequestSpecificException", nullptr, nullptr);

    // Add attributes to the Python exception class
    py::handle(request_specific_exc).attr("request_id") = py::none();
    py::handle(request_specific_exc).attr("error_code") = py::none();

    m.add_object("RequestSpecificException", py::handle(request_specific_exc));

    // Register exception translator to convert C++ exceptions to Python
    py::register_exception_translator(
        [](std::exception_ptr p)
        {
            try
            {
                if (p)
                    std::rethrow_exception(p);
            }
            catch (const tc::RequestSpecificException& e)
            {
                // Create a Python exception with the request ID and error code information
                py::object msg = py::str(e.what());
                py::object inst = py::reinterpret_steal<py::object>(
                    PyObject_CallFunctionObjArgs(request_specific_exc, msg.ptr(), nullptr));

                inst.attr("request_id") = py::cast(e.getRequestId());
                inst.attr("error_code") = py::cast(e.getErrorCode());

                PyErr_SetObject(request_specific_exc, inst.ptr());
            }
        });
}

} // namespace tensorrt_llm::pybind::common
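A sketch of how the translator surfaces on the binding side, assuming a hypothetical `fail_request` helper that throws `tc::RequestSpecificException`; the constructor arguments (message, request id, error code) are assumed from the accessors the translator uses above, not taken from the real header:

#include <cstdint>

#include <pybind11/pybind11.h>

#include "tensorrt_llm/common/tllmException.h"

namespace py = pybind11;
namespace tc = tensorrt_llm::common;

PYBIND11_MODULE(exc_demo, m)
{
    // Illustrative only; the RequestSpecificException constructor signature
    // is an assumption for this sketch.
    m.def("fail_request",
        [](uint64_t requestId)
        {
            throw tc::RequestSpecificException(
                "request failed", requestId, tc::RequestErrorCode::kNETWORK_ERROR);
        });
}

On the Python side, the registered translator turns this into a `tensorrt_llm.RequestSpecificException` whose `request_id` and `error_code` attributes are populated from the C++ exception.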
@ -1,32 +0,0 @@

/*
 * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "tensorrt_llm/common/tllmException.h"
#include <pybind11/pybind11.h>

namespace py = pybind11;
namespace tc = tensorrt_llm::common;

namespace tensorrt_llm::pybind::common
{

/// @brief Bind RequestSpecificException and related types to Python
/// @param m The pybind11 module to bind to
void initExceptionsBindings(py::module_& m);

} // namespace tensorrt_llm::pybind::common
@ -1,280 +0,0 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "bindings.h"
|
||||
#include "executor.h"
|
||||
#include "executorConfig.h"
|
||||
#include "request.h"
|
||||
#include "tensorrt_llm/executor/executor.h"
|
||||
#include "tensorrt_llm/executor/types.h"
|
||||
|
||||
#include <pybind11/cast.h>
|
||||
#include <pybind11/chrono.h>
|
||||
#include <pybind11/functional.h>
|
||||
#include <pybind11/operators.h>
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
#include <optional>
|
||||
|
||||
namespace py = pybind11;
|
||||
namespace tle = tensorrt_llm::executor;
|
||||
using SizeType32 = tle::SizeType32;
|
||||
|
||||
namespace tensorrt_llm::pybind::executor
|
||||
{
|
||||
|
||||
template <typename T>
|
||||
void instantiateEventDiff(pybind11::module& m, std::string const& name)
|
||||
{
|
||||
py::class_<tle::KVCacheEventDiff<T>>(m, ("KVCacheEventDiff" + name).c_str())
|
||||
.def_readonly("old_value", &tle::KVCacheEventDiff<T>::oldValue)
|
||||
.def_readonly("new_value", &tle::KVCacheEventDiff<T>::newValue);
|
||||
}
|
||||
|
||||
void initBindings(pybind11::module_& m)
|
||||
{
|
||||
m.attr("__version__") = tle::version();
|
||||
py::enum_<tle::ModelType>(m, "ModelType")
|
||||
.value("DECODER_ONLY", tle::ModelType::kDECODER_ONLY)
|
||||
.value("ENCODER_ONLY", tle::ModelType::kENCODER_ONLY)
|
||||
.value("ENCODER_DECODER", tle::ModelType::kENCODER_DECODER);
|
||||
|
||||
auto decodingModeGetstate = [](tle::DecodingMode const& self) { return py::make_tuple(self.getState()); };
|
||||
auto decodingModeSetstate = [](py::tuple const& state)
|
||||
{
|
||||
if (state.size() != 1)
|
||||
{
|
||||
throw std::runtime_error("Invalid state!");
|
||||
}
|
||||
return tle::DecodingMode(state[0].cast<tle::DecodingMode::UnderlyingType>());
|
||||
};
|
||||
py::class_<tle::DecodingMode>(m, "DecodingMode")
|
||||
.def("Auto", &tle::DecodingMode::Auto)
|
||||
.def("TopK", &tle::DecodingMode::TopK)
|
||||
.def("TopP", &tle::DecodingMode::TopP)
|
||||
.def("TopKTopP", &tle::DecodingMode::TopKTopP)
|
||||
.def("BeamSearch", &tle::DecodingMode::BeamSearch)
|
||||
.def("Medusa", &tle::DecodingMode::Medusa)
|
||||
.def("Lookahead", &tle::DecodingMode::Lookahead)
|
||||
.def("ExplicitDraftTokens", &tle::DecodingMode::ExplicitDraftTokens)
|
||||
.def("Eagle", &tle::DecodingMode::Eagle)
|
||||
.def("isAuto", &tle::DecodingMode::isAuto)
|
||||
.def("isTopK", &tle::DecodingMode::isTopK)
|
||||
.def("isTopP", &tle::DecodingMode::isTopP)
|
||||
.def("isTopKorTopP", &tle::DecodingMode::isTopKorTopP)
|
||||
.def("isTopKandTopP", &tle::DecodingMode::isTopKandTopP)
|
||||
.def("isBeamSearch", &tle::DecodingMode::isBeamSearch)
|
||||
.def("isMedusa", &tle::DecodingMode::isMedusa)
|
||||
.def("isLookahead", &tle::DecodingMode::isLookahead)
|
||||
.def("isExplicitDraftTokens", &tle::DecodingMode::isExplicitDraftTokens)
|
||||
.def("isEagle", &tle::DecodingMode::isEagle)
|
||||
.def("useVariableBeamWidthSearch", &tle::DecodingMode::useVariableBeamWidthSearch)
|
||||
.def_property_readonly("name", &tle::DecodingMode::getName)
|
||||
.def(py::pickle(decodingModeGetstate, decodingModeSetstate));
|
||||
|
||||
py::enum_<tle::CapacitySchedulerPolicy>(m, "CapacitySchedulerPolicy")
|
||||
.value("MAX_UTILIZATION", tle::CapacitySchedulerPolicy::kMAX_UTILIZATION)
|
||||
.value("GUARANTEED_NO_EVICT", tle::CapacitySchedulerPolicy::kGUARANTEED_NO_EVICT)
|
||||
.value("STATIC_BATCH", tle::CapacitySchedulerPolicy::kSTATIC_BATCH);
|
||||
|
||||
py::enum_<tle::ContextChunkingPolicy>(m, "ContextChunkingPolicy")
|
||||
.value("EQUAL_PROGRESS", tle::ContextChunkingPolicy::kEQUAL_PROGRESS)
|
||||
.value("FIRST_COME_FIRST_SERVED", tle::ContextChunkingPolicy::kFIRST_COME_FIRST_SERVED);
|
||||
|
||||
    py::enum_<tle::CommunicationType>(m, "CommunicationType").value("MPI", tle::CommunicationType::kMPI);

    py::enum_<tle::CommunicationMode>(m, "CommunicationMode")
        .value("LEADER", tle::CommunicationMode::kLEADER)
        .value("ORCHESTRATOR", tle::CommunicationMode::kORCHESTRATOR);

    py::class_<tle::KvCacheStats>(m, "KvCacheStats")
        .def(py::init<>())
        .def_readwrite("max_num_blocks", &tle::KvCacheStats::maxNumBlocks)
        .def_readwrite("free_num_blocks", &tle::KvCacheStats::freeNumBlocks)
        .def_readwrite("used_num_blocks", &tle::KvCacheStats::usedNumBlocks)
        .def_readwrite("tokens_per_block", &tle::KvCacheStats::tokensPerBlock)
        .def_readwrite("alloc_total_blocks", &tle::KvCacheStats::allocTotalBlocks)
        .def_readwrite("alloc_new_blocks", &tle::KvCacheStats::allocNewBlocks)
        .def_readwrite("reused_blocks", &tle::KvCacheStats::reusedBlocks)
        .def_readwrite("missed_blocks", &tle::KvCacheStats::missedBlocks)
        .def_readwrite("cache_hit_rate", &tle::KvCacheStats::cacheHitRate);

    py::class_<tle::StaticBatchingStats>(m, "StaticBatchingStats")
        .def(py::init<>())
        .def_readwrite("num_scheduled_requests", &tle::StaticBatchingStats::numScheduledRequests)
        .def_readwrite("num_context_requests", &tle::StaticBatchingStats::numContextRequests)
        .def_readwrite("num_ctx_tokens", &tle::StaticBatchingStats::numCtxTokens)
        .def_readwrite("num_gen_tokens", &tle::StaticBatchingStats::numGenTokens)
        .def_readwrite("empty_gen_slots", &tle::StaticBatchingStats::emptyGenSlots);

    py::class_<tle::InflightBatchingStats>(m, "InflightBatchingStats")
        .def(py::init<>())
        .def_readwrite("num_scheduled_requests", &tle::InflightBatchingStats::numScheduledRequests)
        .def_readwrite("num_context_requests", &tle::InflightBatchingStats::numContextRequests)
        .def_readwrite("num_gen_requests", &tle::InflightBatchingStats::numGenRequests)
        .def_readwrite("num_paused_requests", &tle::InflightBatchingStats::numPausedRequests)
        .def_readwrite("num_ctx_tokens", &tle::InflightBatchingStats::numCtxTokens)
        .def_readwrite("micro_batch_id", &tle::InflightBatchingStats::microBatchId)
        .def_readwrite("avg_num_decoded_tokens_per_iter", &tle::InflightBatchingStats::avgNumDecodedTokensPerIter);

    py::class_<tle::SpecDecodingStats>(m, "SpecDecodingStats")
        .def(py::init<>())
        .def_readwrite("num_draft_tokens", &tle::SpecDecodingStats::numDraftTokens)
        .def_readwrite("num_accepted_tokens", &tle::SpecDecodingStats::numAcceptedTokens)
        .def_readwrite("num_requests_with_draft_tokens", &tle::SpecDecodingStats::numRequestsWithDraftTokens)
        .def_readwrite("acceptance_length", &tle::SpecDecodingStats::acceptanceLength)
        .def_readwrite("iter_latency_ms", &tle::SpecDecodingStats::iterLatencyMS)
        .def_readwrite("draft_overhead", &tle::SpecDecodingStats::draftOverhead);

    py::class_<tle::IterationStats>(m, "IterationStats")
        .def(py::init<>())
        .def_readwrite("timestamp", &tle::IterationStats::timestamp)
        .def_readwrite("iter", &tle::IterationStats::iter)
        .def_readwrite("iter_latency_ms", &tle::IterationStats::iterLatencyMS)
        .def_readwrite("new_active_requests_queue_latency_ms", &tle::IterationStats::newActiveRequestsQueueLatencyMS)
        .def_readwrite("num_new_active_requests", &tle::IterationStats::numNewActiveRequests)
        .def_readwrite("num_active_requests", &tle::IterationStats::numActiveRequests)
        .def_readwrite("num_queued_requests", &tle::IterationStats::numQueuedRequests)
        .def_readwrite("num_completed_requests", &tle::IterationStats::numCompletedRequests)
        .def_readwrite("max_num_active_requests", &tle::IterationStats::maxNumActiveRequests)
        .def_readwrite("gpu_mem_usage", &tle::IterationStats::gpuMemUsage)
        .def_readwrite("cpu_mem_usage", &tle::IterationStats::cpuMemUsage)
        .def_readwrite("pinned_mem_usage", &tle::IterationStats::pinnedMemUsage)
        .def_readwrite("kv_cache_stats", &tle::IterationStats::kvCacheStats)
        .def_readwrite("cross_kv_cache_stats", &tle::IterationStats::crossKvCacheStats)
        .def_readwrite("static_batching_stats", &tle::IterationStats::staticBatchingStats)
        .def_readwrite("inflight_batching_stats", &tle::IterationStats::inflightBatchingStats)
        .def_readwrite("specdec_stats", &tle::IterationStats::specDecodingStats)
        .def("to_json_str",
            [](tle::IterationStats const& iterationStats)
            { return tle::JsonSerialization::toJsonStr(iterationStats); });

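// A minimal sketch of the "to_json_str" idiom above: a free serialization
// function is attached as an instance method via a lambda whose first
// parameter plays the role of `self`. The names Stats/toJsonStr/stats_sketch
// are hypothetical, not part of the TensorRT-LLM API.
#include <pybind11/pybind11.h>
#include <string>
namespace py = pybind11;

struct Stats
{
    int iter = 0;
};

std::string toJsonStr(Stats const& s) // stand-in for tle::JsonSerialization::toJsonStr
{
    return "{\"iter\": " + std::to_string(s.iter) + "}";
}

PYBIND11_MODULE(stats_sketch, m)
{
    py::class_<Stats>(m, "Stats")
        .def(py::init<>())
        .def_readwrite("iter", &Stats::iter)
        .def("to_json_str", [](Stats const& s) { return toJsonStr(s); });
}
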
    py::class_<tle::DebugTensorsPerIteration>(m, "DebugTensorsPerIteration")
        .def(py::init<>())
        .def_readwrite("iter", &tle::DebugTensorsPerIteration::iter)
        .def_readwrite("debug_tensors", &tle::DebugTensorsPerIteration::debugTensors);

    py::enum_<tle::RequestStage>(m, "RequestStage")
        .value("QUEUED", tle::RequestStage::kQUEUED)
        .value("ENCODER_IN_PROGRESS", tle::RequestStage::kENCODER_IN_PROGRESS)
        .value("CONTEXT_IN_PROGRESS", tle::RequestStage::kCONTEXT_IN_PROGRESS)
        .value("GENERATION_IN_PROGRESS", tle::RequestStage::kGENERATION_IN_PROGRESS)
        .value("GENERATION_COMPLETE", tle::RequestStage::kGENERATION_COMPLETE);

    py::class_<tle::DisServingRequestStats>(m, "DisServingRequestStats")
        .def(py::init<>())
        .def_readwrite("kv_cache_transfer_ms", &tle::DisServingRequestStats::kvCacheTransferMS)
        .def_readwrite("kv_cache_size", &tle::DisServingRequestStats::kvCacheSize);

    py::class_<tle::RequestStats>(m, "RequestStats")
        .def(py::init<>())
        .def_readwrite("id", &tle::RequestStats::id)
        .def_readwrite("stage", &tle::RequestStats::stage)
        .def_readwrite("context_prefill_position", &tle::RequestStats::contextPrefillPosition)
        .def_readwrite("num_generated_tokens", &tle::RequestStats::numGeneratedTokens)
        .def_readwrite("avg_num_decoded_tokens_per_iter", &tle::RequestStats::avgNumDecodedTokensPerIter)
        .def_readwrite("scheduled", &tle::RequestStats::scheduled)
        .def_readwrite("paused", &tle::RequestStats::paused)
        .def_readwrite("dis_serving_stats", &tle::RequestStats::disServingStats)
        .def_readwrite("alloc_total_blocks_per_request", &tle::RequestStats::allocTotalBlocksPerRequest)
        .def_readwrite("alloc_new_blocks_per_request", &tle::RequestStats::allocNewBlocksPerRequest)
        .def_readwrite("reused_blocks_per_request", &tle::RequestStats::reusedBlocksPerRequest)
        .def_readwrite("missed_blocks_per_request", &tle::RequestStats::missedBlocksPerRequest)
        .def_readwrite("kv_cache_hit_rate_per_request", &tle::RequestStats::kvCacheHitRatePerRequest)
        .def("to_json_str",
            [](tle::RequestStats const& requestStats) { return tle::JsonSerialization::toJsonStr(requestStats); });

    py::class_<tle::RequestStatsPerIteration>(m, "RequestStatsPerIteration")
        .def(py::init<>())
        .def_readwrite("iter", &tle::RequestStatsPerIteration::iter)
        .def_readwrite("request_stats", &tle::RequestStatsPerIteration::requestStats)
        .def("to_json_str",
            [](tle::RequestStatsPerIteration const& requestStats)
            { return tle::JsonSerialization::toJsonStr(requestStats); });

    py::module_ executor_kv_cache = m.def_submodule("kv_cache", "Executor KV Cache Manager");

    py::class_<tle::KVCacheCreatedData>(executor_kv_cache, "KVCacheCreatedData")
        .def_readonly("num_blocks_per_cache_level", &tle::KVCacheCreatedData::numBlocksPerCacheLevel);

    py::class_<tensorrt_llm::runtime::UniqueToken>(executor_kv_cache, "UniqueToken")
        .def_readonly("token_id", &tensorrt_llm::runtime::UniqueToken::tokenId)
        .def_readonly("token_extra_id", &tensorrt_llm::runtime::UniqueToken::tokenExtraId);

    py::class_<tle::KVCacheStoredBlockData>(executor_kv_cache, "KVCacheStoredBlockData")
        .def_readonly("block_hash", &tle::KVCacheStoredBlockData::blockHash)
        .def_readonly("tokens", &tle::KVCacheStoredBlockData::tokens)
        .def_readonly("lora_id", &tle::KVCacheStoredBlockData::loraId)
        .def_readonly("cache_level", &tle::KVCacheStoredBlockData::cacheLevel)
        .def_readonly("priority", &tle::KVCacheStoredBlockData::priority)
        .def_property_readonly("mm_keys",
            [](tle::KVCacheStoredBlockData const& self)
            {
                // Convert std::vector<MmKey> to Python list of tuples (bytes, int)
                // MmKey = std::pair<std::array<uint8_t, 32>, SizeType32>
                py::list result;
                for (auto const& mmKey : self.mmKeys)
                {
                    auto const& hashArray = mmKey.first;
                    auto offset = mmKey.second;
                    py::bytes hashBytes(reinterpret_cast<char const*>(hashArray.data()), hashArray.size());
                    result.append(py::make_tuple(hashBytes, offset));
                }
                return result;
            });

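// A reduced sketch of the def_property_readonly conversion used for mm_keys
// above: internal C++ storage is surfaced as a Python list of (bytes, int)
// tuples without a custom type caster. Block/keys/block_sketch are
// illustrative names only.
#include <pybind11/pybind11.h>
#include <array>
#include <cstdint>
#include <utility>
#include <vector>
namespace py = pybind11;

struct Block
{
    std::vector<std::pair<std::array<uint8_t, 32>, int>> keys;
};

PYBIND11_MODULE(block_sketch, m)
{
    py::class_<Block>(m, "Block")
        .def(py::init<>())
        .def_property_readonly("keys",
            [](Block const& self)
            {
                py::list out;
                for (auto const& [hash, offset] : self.keys)
                {
                    // Raw digest bytes become an immutable Python bytes object.
                    out.append(py::make_tuple(
                        py::bytes(reinterpret_cast<char const*>(hash.data()), hash.size()), offset));
                }
                return out;
            });
}
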
    py::class_<tle::KVCacheStoredData>(executor_kv_cache, "KVCacheStoredData")
        .def_readonly("parent_hash", &tle::KVCacheStoredData::parentHash)
        .def_readonly("blocks", &tle::KVCacheStoredData::blocks);

    py::class_<tle::KVCacheRemovedData>(executor_kv_cache, "KVCacheRemovedData")
        .def_readonly("block_hashes", &tle::KVCacheRemovedData::blockHashes);

    instantiateEventDiff<SizeType32>(executor_kv_cache, "Int");

    py::class_<tle::KVCacheUpdatedData>(executor_kv_cache, "KVCacheUpdatedData")
        .def_readonly("block_hash", &tle::KVCacheUpdatedData::blockHash)
        .def_readonly("cache_level", &tle::KVCacheUpdatedData::cacheLevel)
        .def_readonly("priority", &tle::KVCacheUpdatedData::priority);

    py::class_<tle::KVCacheEvent>(executor_kv_cache, "KVCacheEvent")
        .def_readonly("event_id", &tle::KVCacheEvent::eventId)
        .def_readonly("data", &tle::KVCacheEvent::data)
        .def_readonly("window_size", &tle::KVCacheEvent::windowSize)
        .def_readonly("attention_dp_rank", &tle::KVCacheEvent::attentionDpRank);

    py::class_<tle::KVCacheEventManager, std::shared_ptr<tle::KVCacheEventManager>>(
        executor_kv_cache, "KVCacheEventManager")
        .def(
            "get_latest_events",
            [](tle::KVCacheEventManager& self, std::optional<double> timeout_ms = std::nullopt)
            {
                if (timeout_ms)
                {
                    return self.getLatestEvents(std::chrono::milliseconds(static_cast<int64_t>(*timeout_ms)));
                }
                return self.getLatestEvents(std::nullopt);
            },
            py::arg("timeout_ms") = std::nullopt);

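// A sketch of the optional-timeout convention used by get_latest_events
// above: Python None maps to std::nullopt ("wait indefinitely"), while a
// float is interpreted as milliseconds. poll/wait_sketch are hypothetical
// names, not part of this module.
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <chrono>
#include <cstdint>
#include <optional>
namespace py = pybind11;

bool poll(std::optional<std::chrono::milliseconds> timeout)
{
    return timeout.has_value(); // stand-in for a real blocking wait
}

PYBIND11_MODULE(wait_sketch, m)
{
    m.def(
        "poll",
        [](std::optional<double> timeout_ms)
        {
            if (timeout_ms)
            {
                return poll(std::chrono::milliseconds(static_cast<int64_t>(*timeout_ms)));
            }
            return poll(std::nullopt);
        },
        py::arg("timeout_ms") = std::nullopt);
}
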
    tensorrt_llm::pybind::executor::initRequestBindings(m);
    tensorrt_llm::pybind::executor::initConfigBindings(m);
    tensorrt_llm::pybind::executor::Executor::initBindings(m);
}

} // namespace tensorrt_llm::pybind::executor
@ -1,29 +0,0 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "tensorrt_llm/pybind/common/customCasters.h"
#include <pybind11/pybind11.h>

namespace tensorrt_llm::pybind::executor
{

// Register bindings for executor API.
void initBindings(pybind11::module_& m);

} // namespace tensorrt_llm::pybind::executor
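// A sketch of the registration layout this header reflects: each binding
// translation unit exports an initBindings(m) hook, and a central
// PYBIND11_MODULE entry point calls the hooks in dependency order. The
// feature_a/feature_b/bindings_sketch names are illustrative only.
#include <pybind11/pybind11.h>
namespace py = pybind11;

namespace feature_a
{
void initBindings(py::module_& m)
{
    m.attr("A") = 1;
}
} // namespace feature_a

namespace feature_b
{
void initBindings(py::module_& m)
{
    m.attr("B") = 2;
}
} // namespace feature_b

PYBIND11_MODULE(bindings_sketch, m)
{
    feature_a::initBindings(m);
    feature_b::initBindings(m);
}
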
@ -1,191 +0,0 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "executor.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/executor/tensor.h"

#include <pybind11/chrono.h>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

namespace py = pybind11;
namespace tle = tensorrt_llm::executor;

namespace
{
tle::Tensor numpyToTensor(py::array const& array)
{
    auto npDtype = array.dtype();
    tle::DataType dtype;
    if (npDtype.is(py::dtype("float16")))
    {
        dtype = tle::DataType::kFP16;
    }
    else if (npDtype.is(py::dtype("float32")))
    {
        dtype = tle::DataType::kFP32;
    }
    else if (npDtype.is(py::dtype("int8")))
    {
        dtype = tle::DataType::kINT8;
    }
    else if (npDtype.is(py::dtype("int32")))
    {
        dtype = tle::DataType::kINT32;
    }
    else if (npDtype.is(py::dtype("int64")))
    {
        dtype = tle::DataType::kINT64;
    }
    else if (npDtype.attr("kind").cast<std::string>() == "V" && npDtype.attr("itemsize").cast<int>() == 1
        && npDtype.attr("metadata")["dtype"].cast<std::string>() == "float8")
    {
        dtype = tle::DataType::kFP8;
    }
    else if (npDtype.attr("kind").cast<std::string>() == "V" && npDtype.attr("itemsize").cast<int>() == 2
        && npDtype.attr("metadata")["dtype"].cast<std::string>() == "bfloat16")
    {
        dtype = tle::DataType::kBF16;
    }
    else
    {
        TLLM_THROW("Unsupported numpy dtype: " + npDtype.attr("name").cast<std::string>());
    }

    tle::Shape shape(array.shape(), array.ndim());

    return tle::Tensor::of(dtype, const_cast<void*>(array.data()), shape);
}
} // namespace

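// A note on the two "V"-kind branches above: extension float formats such as
// bfloat16 and float8 are typically registered with numpy (for example by
// packages like ml_dtypes) as void dtypes of 1-2 bytes whose metadata dict
// records the real format name; the itemsize/metadata probe detects exactly
// that. The helper below is a hypothetical restatement of the bfloat16 branch.
#include <pybind11/numpy.h>
#include <string>
namespace py = pybind11;

bool isBF16(py::dtype const& dt)
{
    return dt.attr("kind").cast<std::string>() == "V" && dt.attr("itemsize").cast<int>() == 2
        && dt.attr("metadata")["dtype"].cast<std::string>() == "bfloat16";
}
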
namespace tensorrt_llm::pybind::executor
{

Executor::Executor(
    std::filesystem::path const& modelPath, tle::ModelType modelType, tle::ExecutorConfig const& executorConfig)
{
    mExecutor = std::make_unique<tle::Executor>(modelPath, modelType, executorConfig);
}

Executor::Executor(std::filesystem::path const& encoderModelPath, std::filesystem::path const& decoderModelPath,
    tle::ModelType modelType, tle::ExecutorConfig const& executorConfig)
{
    mExecutor = std::make_unique<tle::Executor>(encoderModelPath, decoderModelPath, modelType, executorConfig);
}

Executor::Executor(pybind11::buffer engineBuffer, std::string const& jsonConfigStr, tle::ModelType modelType,
    tle::ExecutorConfig const& executorConfig, std::optional<pybind11::dict> managedWeights)
{
    py::buffer_info info = engineBuffer.request();
    uint8_t const* data = reinterpret_cast<uint8_t const*>(info.ptr);
    size_t size = info.size;
    std::optional<std::map<std::string, tle::Tensor>> managedWeightsMap = std::nullopt;
    if (managedWeights.has_value() && !managedWeights.value().empty())
    {
        managedWeightsMap = std::map<std::string, tle::Tensor>();
        for (auto const& item : managedWeights.value())
        {
            std::string name = item.first.cast<std::string>();
            py::array array = item.second.cast<py::array>();
            managedWeightsMap->emplace(name, numpyToTensor(array));
        }
    }
    mExecutor = std::make_unique<tle::Executor>(
        tle::BufferView(data, size), jsonConfigStr, modelType, executorConfig, managedWeightsMap);
}

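// A sketch of the buffer-protocol pattern used by the engine-buffer
// constructor above: request() yields a buffer_info whose ptr gives zero-copy
// access to the caller's memory, and whose size counts items (multiply by
// itemsize for bytes). checksum/buffer_sketch are hypothetical names.
#include <pybind11/pybind11.h>
#include <cstdint>
namespace py = pybind11;

PYBIND11_MODULE(buffer_sketch, m)
{
    m.def("checksum",
        [](py::buffer b)
        {
            py::buffer_info info = b.request();
            auto const* data = static_cast<uint8_t const*>(info.ptr);
            uint64_t sum = 0;
            for (py::ssize_t i = 0; i < info.size * info.itemsize; ++i)
            {
                sum += data[i];
            }
            return sum;
        });
}
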
Executor::Executor(std::string const& encoderEngineBuffer, std::string const& encoderJsonConfigStr,
    std::string const& decoderEngineBuffer, std::string const& decoderJsonConfigStr, tle::ModelType modelType,
    tle::ExecutorConfig const& executorConfig)
{
    uint8_t const* encoderData = reinterpret_cast<uint8_t const*>(encoderEngineBuffer.data());
    size_t encoderSize = encoderEngineBuffer.size();
    uint8_t const* decoderData = reinterpret_cast<uint8_t const*>(decoderEngineBuffer.data());
    size_t decoderSize = decoderEngineBuffer.size();
    mExecutor = std::make_unique<tle::Executor>(tle::BufferView(encoderData, encoderSize), encoderJsonConfigStr,
        tle::BufferView(decoderData, decoderSize), decoderJsonConfigStr, modelType, executorConfig);
}

py::object Executor::enter()
{
    TLLM_CHECK(static_cast<bool>(mExecutor));
    return py::cast(this);
}

void Executor::exit(
    [[maybe_unused]] py::handle type, [[maybe_unused]] py::handle value, [[maybe_unused]] py::handle traceback)
{
    shutdown();
    mExecutor = nullptr;
}

void Executor::shutdown()
{
    // NOTE: We must release the GIL here. The Executor has spawned a thread for its execution loop, and that thread
    // must be able to make forward progress for the shutdown process to succeed. It takes the GIL during its
    // callbacks, so we release it now. Note that we must not touch any Python objects after releasing it.
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    py::gil_scoped_release release;
    mExecutor->shutdown();
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

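// A minimal sketch of the GIL discipline described in the comment above:
// release the GIL before a blocking C++ call so that worker threads can take
// it for their Python callbacks, and touch no Python state while it is
// released. blockingWork/callBlocking are hypothetical names.
#include <pybind11/pybind11.h>
namespace py = pybind11;

void blockingWork()
{
    // stand-in for tle::Executor::shutdown(), which waits on a worker thread
}

void callBlocking()
{
    py::gil_scoped_release release; // GIL is re-acquired when `release` is destroyed
    blockingWork();                 // must not create or touch Python objects here
}
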
void Executor::initBindings(py::module_& m)
{
    py::class_<Executor>(m, "Executor")
        .def(py::init<std::filesystem::path const&, tle::ModelType, tle::ExecutorConfig const&>(),
            py::arg("model_path"), py::arg("model_type"), py::arg("executor_config"))
        .def(py::init<std::filesystem::path const&, std::filesystem::path const&, tle::ModelType,
                 tle::ExecutorConfig const&>(),
            py::arg("encoder_model_path"), py::arg("decoder_model_path"), py::arg("model_type"),
            py::arg("executor_config"))
        .def(py::init<py::buffer, std::string const&, tle::ModelType, tle::ExecutorConfig const&, py::dict>(),
            py::arg("engine_buffer"), py::arg("json_config_str"), py::arg("model_type"), py::arg("executor_config"),
            py::arg("managed_weights") = py::dict())
        .def(py::init<std::string const&, std::string const&, std::string const&, std::string const&, tle::ModelType,
                 tle::ExecutorConfig const&>(),
            py::arg("encoder_engine_buffer"), py::arg("encoder_json_config_str"), py::arg("decoder_engine_buffer"),
            py::arg("decoder_json_config_str"), py::arg("model_type"), py::arg("executor_config"))
        .def("shutdown", &Executor::shutdown)
        .def("__enter__", &Executor::enter)
        .def("__exit__", &Executor::exit)
        .def("enqueue_request", &Executor::enqueueRequest, py::arg("request"))
        .def("enqueue_requests", &Executor::enqueueRequests, py::arg("requests"))
        .def("await_responses",
            py::overload_cast<std::optional<std::chrono::milliseconds> const&>(&Executor::awaitResponses),
            py::arg("timeout") = py::none())
        .def("await_responses",
            py::overload_cast<tle::IdType const&, std::optional<std::chrono::milliseconds> const&>(
                &Executor::awaitResponses),
            py::arg("id"), py::arg("timeout") = py::none())
        .def("await_responses",
            py::overload_cast<std::vector<tle::IdType> const&, std::optional<std::chrono::milliseconds> const&>(
                &Executor::awaitResponses),
            py::arg("ids"), py::arg("timeout") = py::none())
        .def("get_num_responses_ready", &Executor::getNumResponsesReady, py::arg("id") = py::none())
        .def("cancel_request", &Executor::cancelRequest, py::arg("id") = py::none())
        .def("get_latest_iteration_stats", &Executor::getLatestIterationStats)
        .def("get_latest_request_stats", &Executor::getLatestRequestStats)
        .def("get_latest_debug_tensors", &Executor::getLatestDebugTensors)
        .def("can_enqueue_requests", &Executor::canEnqueueRequests)
        .def("get_kv_cache_event_manager", &Executor::getKVCacheEventManager);
}

} // namespace tensorrt_llm::pybind::executor
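// A sketch of the py::overload_cast disambiguation used for await_responses
// above: each C++ overload is selected by its parameter types and registered
// under the same Python method name. Greeter/overload_sketch are hypothetical.
#include <pybind11/pybind11.h>
#include <string>
namespace py = pybind11;

struct Greeter
{
    std::string greet()
    {
        return "hi";
    }

    std::string greet(std::string const& name)
    {
        return "hi " + name;
    }
};

PYBIND11_MODULE(overload_sketch, m)
{
    py::class_<Greeter>(m, "Greeter")
        .def(py::init<>())
        .def("greet", py::overload_cast<>(&Greeter::greet))
        .def("greet", py::overload_cast<std::string const&>(&Greeter::greet), py::arg("name"));
}
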
@ -1,129 +0,0 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/executor/types.h"
#include "tensorrt_llm/pybind/common/customCasters.h"
#include <pybind11/pybind11.h>

namespace tle = tensorrt_llm::executor;

namespace tensorrt_llm::pybind::executor
{

class Executor
{
public:
    Executor(
        std::filesystem::path const& modelPath, tle::ModelType modelType, tle::ExecutorConfig const& executorConfig);

    Executor(std::filesystem::path const& encoderModelPath, std::filesystem::path const& decoderModelPath,
        tle::ModelType modelType, tle::ExecutorConfig const& executorConfig);

    Executor(pybind11::buffer engineBuffer, std::string const& jsonConfigStr, tle::ModelType modelType,
        tle::ExecutorConfig const& executorConfig, std::optional<pybind11::dict> managedWeights);

    Executor(std::string const& encoderEngineBuffer, std::string const& encoderJsonConfigStr,
        std::string const& decoderEngineBuffer, std::string const& decoderJsonConfigStr, tle::ModelType modelType,
        tle::ExecutorConfig const& executorConfig);

    pybind11::object enter();
    void exit([[maybe_unused]] pybind11::handle type, [[maybe_unused]] pybind11::handle value,
        [[maybe_unused]] pybind11::handle traceback);
    void shutdown();

    [[nodiscard]] tle::IdType enqueueRequest(tle::Request const& request)
    {
        return mExecutor->enqueueRequest(request);
    }

    [[nodiscard]] std::vector<tle::IdType> enqueueRequests(std::vector<tle::Request> const& requests)
    {
        return mExecutor->enqueueRequests(requests);
    }

    [[nodiscard]] std::vector<tle::Response> awaitResponses(
        std::optional<std::chrono::milliseconds> const& timeout = std::nullopt)
    {
        // awaitResponses blocks until a response is received. Release the GIL so that it can run in a background
        // thread.
        pybind11::gil_scoped_release release;
        return mExecutor->awaitResponses(timeout);
    }

    [[nodiscard]] std::vector<tle::Response> awaitResponses(
        tle::IdType const& requestId, std::optional<std::chrono::milliseconds> const& timeout = std::nullopt)
    {
        // awaitResponses blocks until a response is received. Release the GIL so that it can run in a background
        // thread.
        pybind11::gil_scoped_release release;
        return mExecutor->awaitResponses(requestId, timeout);
    }

    [[nodiscard]] std::vector<std::vector<tle::Response>> awaitResponses(std::vector<tle::IdType> const& requestIds,
        std::optional<std::chrono::milliseconds> const& timeout = std::nullopt)
    {
        // awaitResponses blocks until a response is received. Release the GIL so that it can run in a background
        // thread.
        pybind11::gil_scoped_release release;
        return mExecutor->awaitResponses(requestIds, timeout);
    }

    [[nodiscard]] tle::SizeType32 getNumResponsesReady(std::optional<tle::IdType> const& requestId = std::nullopt) const
    {
        return mExecutor->getNumResponsesReady(requestId);
    }

    void cancelRequest(tle::IdType requestId)
    {
        mExecutor->cancelRequest(requestId);
    }

    std::deque<tle::IterationStats> getLatestIterationStats()
    {
        return mExecutor->getLatestIterationStats();
    }

    std::deque<tle::RequestStatsPerIteration> getLatestRequestStats()
    {
        return mExecutor->getLatestRequestStats();
    }

    std::deque<tle::DebugTensorsPerIteration> getLatestDebugTensors()
    {
        return mExecutor->getLatestDebugTensors();
    }

    [[nodiscard]] bool canEnqueueRequests() const
    {
        return mExecutor->canEnqueueRequests();
    }

    [[nodiscard]] std::optional<std::shared_ptr<tle::KVCacheEventManager>> getKVCacheEventManager() const
    {
        return mExecutor->getKVCacheEventManager();
    }

    static void initBindings(pybind11::module_& m);

private:
    std::unique_ptr<tle::Executor> mExecutor;
};

} // namespace tensorrt_llm::pybind::executor
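// A sketch of the context-manager wiring used for Executor above: __enter__
// returns self and __exit__ tears the object down, which is what makes
// `with Executor(...) as e:` shut down cleanly. Resource/ctx_sketch are
// hypothetical names.
#include <pybind11/pybind11.h>
namespace py = pybind11;

struct Resource
{
    bool open = true;

    void close()
    {
        open = false;
    }
};

PYBIND11_MODULE(ctx_sketch, m)
{
    py::class_<Resource>(m, "Resource")
        .def(py::init<>())
        .def_readonly("open", &Resource::open)
        .def("__enter__", [](py::object self) { return self; })
        .def("__exit__", [](Resource& self, py::handle, py::handle, py::handle) { self.close(); });
}
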
@ -1,642 +0,0 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "executorConfig.h"
#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/executor/types.h"
#include "tensorrt_llm/runtime/cudaStream.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
#include <optional>
#include <pybind11/cast.h>
#include <pybind11/functional.h>
#include <pybind11/operators.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <torch/torch.h>
#include <vector>

namespace py = pybind11;
namespace tle = tensorrt_llm::executor;
using SizeType32 = tle::SizeType32;
using RuntimeDefaults = tensorrt_llm::runtime::RuntimeDefaults;

namespace tensorrt_llm::pybind::executor
{

void initConfigBindings(pybind11::module_& m)
{
    py::enum_<tle::BatchingType>(m, "BatchingType")
        .value("STATIC", tle::BatchingType::kSTATIC)
        .value("INFLIGHT", tle::BatchingType::kINFLIGHT);

    auto dynamicBatchConfigGetstate = [](tle::DynamicBatchConfig const& self)
    {
        return py::make_tuple(self.getEnableBatchSizeTuning(), self.getEnableMaxNumTokensTuning(),
            self.getDynamicBatchMovingAverageWindow(), self.getBatchSizeTable());
    };
    auto dynamicBatchConfigSetstate = [](py::tuple const& state)
    {
        if (state.size() != 4)
        {
            throw std::runtime_error("Invalid state!");
        }
        return tle::DynamicBatchConfig(state[0].cast<bool>(), state[1].cast<bool>(), state[2].cast<SizeType32>(),
            state[3].cast<std::vector<std::pair<SizeType32, SizeType32>>>());
    };
    py::class_<tle::DynamicBatchConfig>(m, "DynamicBatchConfig")
        .def(py::init<bool, bool, SizeType32>(), py::arg("enable_batch_size_tuning"),
            py::arg("enable_max_num_tokens_tuning"), py::arg("dynamic_batch_moving_average_window"))
        .def_property_readonly("enable_batch_size_tuning", &tle::DynamicBatchConfig::getEnableBatchSizeTuning)
        .def_property_readonly("enable_max_num_tokens_tuning", &tle::DynamicBatchConfig::getEnableMaxNumTokensTuning)
        .def_property_readonly(
            "dynamic_batch_moving_average_window", &tle::DynamicBatchConfig::getDynamicBatchMovingAverageWindow)
        .def(py::pickle(dynamicBatchConfigGetstate, dynamicBatchConfigSetstate));

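// A sketch of the __getstate__/__setstate__ pattern repeated for every config
// class in this file: getstate packs the fields into a tuple, setstate
// validates the arity and reconstructs the object. Pair/pickle_sketch are
// hypothetical names.
#include <pybind11/pybind11.h>
#include <stdexcept>
namespace py = pybind11;

struct Pair
{
    int a;
    int b;
};

PYBIND11_MODULE(pickle_sketch, m)
{
    py::class_<Pair>(m, "Pair")
        .def(py::init<int, int>(), py::arg("a"), py::arg("b"))
        .def_readwrite("a", &Pair::a)
        .def_readwrite("b", &Pair::b)
        .def(py::pickle([](Pair const& self) { return py::make_tuple(self.a, self.b); },
            [](py::tuple const& state)
            {
                if (state.size() != 2)
                {
                    throw std::runtime_error("Invalid state!");
                }
                return Pair{state[0].cast<int>(), state[1].cast<int>()};
            }));
}
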
    auto schedulerConfigSetstate = [](py::tuple const& state)
    {
        if (state.size() != 3)
        {
            throw std::runtime_error("Invalid state!");
        }
        return tle::SchedulerConfig(state[0].cast<tle::CapacitySchedulerPolicy>(),
            state[1].cast<std::optional<tle::ContextChunkingPolicy>>(),
            state[2].cast<std::optional<tle::DynamicBatchConfig>>());
    };
    auto schedulerConfigGetstate = [](tle::SchedulerConfig const& self)
    {
        return py::make_tuple(
            self.getCapacitySchedulerPolicy(), self.getContextChunkingPolicy(), self.getDynamicBatchConfig());
    };
    py::class_<tle::SchedulerConfig>(m, "SchedulerConfig")
        .def(py::init<tle::CapacitySchedulerPolicy, std::optional<tle::ContextChunkingPolicy>,
                 std::optional<tle::DynamicBatchConfig>>(),
            py::arg_v("capacity_scheduler_policy", tle::CapacitySchedulerPolicy::kGUARANTEED_NO_EVICT,
                "CapacitySchedulerPolicy.GUARANTEED_NO_EVICT"),
            py::arg("context_chunking_policy") = py::none(), py::arg("dynamic_batch_config") = py::none())
        .def_property_readonly("capacity_scheduler_policy", &tle::SchedulerConfig::getCapacitySchedulerPolicy)
        .def_property_readonly("context_chunking_policy", &tle::SchedulerConfig::getContextChunkingPolicy)
        .def_property_readonly("dynamic_batch_config", &tle::SchedulerConfig::getDynamicBatchConfig)
        .def(py::pickle(schedulerConfigGetstate, schedulerConfigSetstate));

    py::class_<RuntimeDefaults>(m, "RuntimeDefaults")
        .def(py::init<std::optional<std::vector<SizeType32>>, std::optional<SizeType32>>(),
            py::arg("max_attention_window") = py::none(), py::arg("sink_token_length") = py::none())
        .def_readonly("max_attention_window", &RuntimeDefaults::maxAttentionWindowVec)
        .def_readonly("sink_token_length", &RuntimeDefaults::sinkTokenLength);

    auto kvCacheConfigGetstate = [](tle::KvCacheConfig const& self)
    {
        return py::make_tuple(self.getEnableBlockReuse(), self.getMaxTokens(), self.getMaxAttentionWindowVec(),
            self.getSinkTokenLength(), self.getFreeGpuMemoryFraction(), self.getHostCacheSize(),
            self.getOnboardBlocks(), self.getCrossKvCacheFraction(), self.getSecondaryOffloadMinPriority(),
            self.getEventBufferMaxSize(), self.getEnablePartialReuse(), self.getCopyOnPartialReuse(), self.getUseUvm(),
            self.getAttentionDpEventsGatherPeriodMs(), self.getMaxGpuTotalBytes());
    };
    auto kvCacheConfigSetstate = [](py::tuple const& state)
    {
        if (state.size() != 15)
        {
            throw std::runtime_error("Invalid state!");
        }
        return tle::KvCacheConfig(state[0].cast<bool>(), state[1].cast<std::optional<SizeType32>>(),
            state[2].cast<std::optional<std::vector<SizeType32>>>(), state[3].cast<std::optional<SizeType32>>(),
            state[4].cast<std::optional<float>>(), state[5].cast<std::optional<size_t>>(), state[6].cast<bool>(),
            state[7].cast<std::optional<float>>(), state[8].cast<std::optional<tle::RetentionPriority>>(),
            state[9].cast<size_t>(), state[10].cast<bool>(), state[11].cast<bool>(), state[12].cast<bool>(),
            state[13].cast<SizeType32>(), std::nullopt, state[14].cast<uint64_t>());
    };
    py::class_<tle::KvCacheConfig>(m, "KvCacheConfig")
        .def(py::init<bool, std::optional<SizeType32> const&, std::optional<std::vector<SizeType32>> const&,
                 std::optional<SizeType32> const&, std::optional<float> const&, std::optional<size_t> const&, bool,
                 std::optional<float> const&, std::optional<tle::RetentionPriority>, size_t const&, bool, bool, bool,
                 SizeType32, std::optional<RuntimeDefaults> const&, uint64_t const&>(),
            py::arg("enable_block_reuse") = true, py::arg("max_tokens") = py::none(),
            py::arg("max_attention_window") = py::none(), py::arg("sink_token_length") = py::none(),
            py::arg("free_gpu_memory_fraction") = py::none(), py::arg("host_cache_size") = py::none(),
            py::arg("onboard_blocks") = true, py::arg("cross_kv_cache_fraction") = py::none(),
            py::arg("secondary_offload_min_priority") = py::none(), py::arg("event_buffer_max_size") = 0, py::kw_only(),
            py::arg("enable_partial_reuse") = true, py::arg("copy_on_partial_reuse") = true, py::arg("use_uvm") = false,
            py::arg("attention_dp_events_gather_period_ms") = 5, py::arg("runtime_defaults") = py::none(),
            py::arg("max_gpu_total_bytes") = 0)
        .def_property(
            "enable_block_reuse", &tle::KvCacheConfig::getEnableBlockReuse, &tle::KvCacheConfig::setEnableBlockReuse)
        .def_property("max_tokens", &tle::KvCacheConfig::getMaxTokens, &tle::KvCacheConfig::setMaxTokens)
        .def_property("max_attention_window", &tle::KvCacheConfig::getMaxAttentionWindowVec,
            &tle::KvCacheConfig::setMaxAttentionWindowVec)
        .def_property(
            "sink_token_length", &tle::KvCacheConfig::getSinkTokenLength, &tle::KvCacheConfig::setSinkTokenLength)
        .def_property("free_gpu_memory_fraction", &tle::KvCacheConfig::getFreeGpuMemoryFraction,
            &tle::KvCacheConfig::setFreeGpuMemoryFraction)
        .def_property(
            "max_gpu_total_bytes", &tle::KvCacheConfig::getMaxGpuTotalBytes, &tle::KvCacheConfig::setMaxGpuTotalBytes)
        .def_property("host_cache_size", &tle::KvCacheConfig::getHostCacheSize, &tle::KvCacheConfig::setHostCacheSize)
        .def_property("onboard_blocks", &tle::KvCacheConfig::getOnboardBlocks, &tle::KvCacheConfig::setOnboardBlocks)
        .def_property("cross_kv_cache_fraction", &tle::KvCacheConfig::getCrossKvCacheFraction,
            &tle::KvCacheConfig::setCrossKvCacheFraction)
        .def_property("secondary_offload_min_priority", &tle::KvCacheConfig::getSecondaryOffloadMinPriority,
            &tle::KvCacheConfig::setSecondaryOffloadMinPriority)
        .def_property("event_buffer_max_size", &tle::KvCacheConfig::getEventBufferMaxSize,
            &tle::KvCacheConfig::setEventBufferMaxSize)
        .def_property("enable_partial_reuse", &tle::KvCacheConfig::getEnablePartialReuse,
            &tle::KvCacheConfig::setEnablePartialReuse)
        .def_property("copy_on_partial_reuse", &tle::KvCacheConfig::getCopyOnPartialReuse,
            &tle::KvCacheConfig::setCopyOnPartialReuse)
        .def_property("use_uvm", &tle::KvCacheConfig::getUseUvm, &tle::KvCacheConfig::setUseUvm)
        .def_property("attention_dp_events_gather_period_ms", &tle::KvCacheConfig::getAttentionDpEventsGatherPeriodMs,
            &tle::KvCacheConfig::setAttentionDpEventsGatherPeriodMs)
        .def("fill_empty_fields_from_runtime_defaults", &tle::KvCacheConfig::fillEmptyFieldsFromRuntimeDefaults)
        .def(py::pickle(kvCacheConfigGetstate, kvCacheConfigSetstate));

    py::class_<tle::OrchestratorConfig>(m, "OrchestratorConfig")
        .def(py::init<bool, std::string, std::shared_ptr<mpi::MpiComm>, bool>(), py::arg("is_orchestrator") = true,
            py::arg("worker_executable_path") = "", py::arg("orch_leader_comm") = nullptr,
            py::arg("spawn_processes") = true)
        .def_property(
            "is_orchestrator", &tle::OrchestratorConfig::getIsOrchestrator, &tle::OrchestratorConfig::setIsOrchestrator)
        .def_property("worker_executable_path", &tle::OrchestratorConfig::getWorkerExecutablePath,
            &tle::OrchestratorConfig::setWorkerExecutablePath)
        .def_property("orch_leader_comm", &tle::OrchestratorConfig::getOrchLeaderComm,
            &tle::OrchestratorConfig::setOrchLeaderComm)
        .def_property("spawn_processes", &tle::OrchestratorConfig::getSpawnProcesses,
            &tle::OrchestratorConfig::setSpawnProcesses);

    auto parallelConfigGetstate = [](tle::ParallelConfig const& self)
    {
        return py::make_tuple(self.getCommunicationType(), self.getCommunicationMode(), self.getDeviceIds(),
            self.getParticipantIds(), self.getOrchestratorConfig(), self.getNumNodes());
    };
    auto parallelConfigSetstate = [](py::tuple const& state)
    {
        if (state.size() != 6)
        {
            throw std::runtime_error("Invalid state!");
        }
        return tle::ParallelConfig(state[0].cast<tle::CommunicationType>(), state[1].cast<tle::CommunicationMode>(),
            state[2].cast<std::optional<std::vector<SizeType32>>>(),
            state[3].cast<std::optional<std::vector<SizeType32>>>(),
            state[4].cast<std::optional<tle::OrchestratorConfig>>(), state[5].cast<std::optional<SizeType32>>());
    };
    py::class_<tle::ParallelConfig>(m, "ParallelConfig")
        .def(py::init<tle::CommunicationType, tle::CommunicationMode, std::optional<std::vector<SizeType32>> const&,
                 std::optional<std::vector<SizeType32>> const&, std::optional<tle::OrchestratorConfig> const&,
                 std::optional<SizeType32> const&>(),
            py::arg_v("communication_type", tle::CommunicationType::kMPI, "CommunicationType.MPI"),
            py::arg_v("communication_mode", tle::CommunicationMode::kLEADER, "CommunicationMode.LEADER"),
            py::arg("device_ids") = py::none(), py::arg("participant_ids") = py::none(),
            py::arg("orchestrator_config") = py::none(), py::arg("num_nodes") = py::none())
        .def_property("communication_type", &tle::ParallelConfig::getCommunicationType,
            &tle::ParallelConfig::setCommunicationType)
        .def_property("communication_mode", &tle::ParallelConfig::getCommunicationMode,
            &tle::ParallelConfig::setCommunicationMode)
        .def_property("device_ids", &tle::ParallelConfig::getDeviceIds, &tle::ParallelConfig::setDeviceIds)
        .def_property(
            "participant_ids", &tle::ParallelConfig::getParticipantIds, &tle::ParallelConfig::setParticipantIds)
        .def_property("orchestrator_config", &tle::ParallelConfig::getOrchestratorConfig,
            &tle::ParallelConfig::setOrchestratorConfig)
        .def_property("num_nodes", &tle::ParallelConfig::getNumNodes, &tle::ParallelConfig::setNumNodes)
        .def(py::pickle(parallelConfigGetstate, parallelConfigSetstate));

    auto peftCacheConfigSetstate = [](py::tuple const& state)
    {
        if (state.size() != 11)
        {
            throw std::runtime_error("Invalid state!");
        }
        return tle::PeftCacheConfig(state[0].cast<SizeType32>(), state[1].cast<SizeType32>(),
            state[2].cast<SizeType32>(), state[3].cast<SizeType32>(), state[4].cast<SizeType32>(),
            state[5].cast<SizeType32>(), state[6].cast<SizeType32>(), state[7].cast<SizeType32>(),
            state[8].cast<SizeType32>(), state[9].cast<std::optional<float>>(),
            state[10].cast<std::optional<size_t>>());
    };
    auto peftCacheConfigGetstate = [](tle::PeftCacheConfig const& self)
    {
        return py::make_tuple(self.getNumHostModuleLayer(), self.getNumDeviceModuleLayer(),
            self.getOptimalAdapterSize(), self.getMaxAdapterSize(), self.getNumPutWorkers(), self.getNumEnsureWorkers(),
            self.getNumCopyStreams(), self.getMaxPagesPerBlockHost(), self.getMaxPagesPerBlockDevice(),
            self.getDeviceCachePercent(), self.getHostCacheSize());
    };
    py::class_<tle::PeftCacheConfig>(m, "PeftCacheConfig")
        .def(py::init<SizeType32, SizeType32, SizeType32, SizeType32, SizeType32, SizeType32, SizeType32, SizeType32,
                 SizeType32, std::optional<float> const&, std::optional<size_t> const&,
                 std::optional<std::string> const&>(),
            py::arg("num_host_module_layer") = 0, py::arg("num_device_module_layer") = 0,
            py::arg("optimal_adapter_size") = 8, py::arg("max_adapter_size") = 64, py::arg("num_put_workers") = 1,
            py::arg("num_ensure_workers") = 1, py::arg("num_copy_streams") = 1,
            py::arg("max_pages_per_block_host") = 24, py::arg("max_pages_per_block_device") = 8,
            py::arg("device_cache_percent") = py::none(), py::arg("host_cache_size") = py::none(),
            py::arg("lora_prefetch_dir") = py::none())
        .def_property_readonly("num_host_module_layer", &tle::PeftCacheConfig::getNumHostModuleLayer)
        .def_property_readonly("num_device_module_layer", &tle::PeftCacheConfig::getNumDeviceModuleLayer)
        .def_property_readonly("optimal_adapter_size", &tle::PeftCacheConfig::getOptimalAdapterSize)
        .def_property_readonly("max_adapter_size", &tle::PeftCacheConfig::getMaxAdapterSize)
        .def_property_readonly("num_put_workers", &tle::PeftCacheConfig::getNumPutWorkers)
        .def_property_readonly("num_ensure_workers", &tle::PeftCacheConfig::getNumEnsureWorkers)
        .def_property_readonly("num_copy_streams", &tle::PeftCacheConfig::getNumCopyStreams)
        .def_property_readonly("max_pages_per_block_host", &tle::PeftCacheConfig::getMaxPagesPerBlockHost)
        .def_property_readonly("max_pages_per_block_device", &tle::PeftCacheConfig::getMaxPagesPerBlockDevice)
        .def_property_readonly("device_cache_percent", &tle::PeftCacheConfig::getDeviceCachePercent)
        .def_property_readonly("host_cache_size", &tle::PeftCacheConfig::getHostCacheSize)
        .def_property_readonly("lora_prefetch_dir", &tle::PeftCacheConfig::getLoraPrefetchDir)
        .def(py::pickle(peftCacheConfigGetstate, peftCacheConfigSetstate));

    auto decodingConfigGetstate = [](tle::DecodingConfig const& self)
    {
        return py::make_tuple(
            self.getDecodingMode(), self.getLookaheadDecodingConfig(), self.getMedusaChoices(), self.getEagleConfig());
    };
    auto decodingConfigSetstate = [](py::tuple const& state)
    {
        if (state.size() != 4)
        {
            throw std::runtime_error("Invalid state!");
        }
        return tle::DecodingConfig(state[0].cast<std::optional<tle::DecodingMode>>(), // DecodingMode
            state[1].cast<std::optional<tle::LookaheadDecodingConfig>>(),             // LookaheadDecodingConfig
            state[2].cast<std::optional<tle::MedusaChoices>>(),                       // MedusaChoices
            state[3].cast<std::optional<tle::EagleConfig>>()                          // EagleConfig
        );
    };
    py::class_<tle::DecodingConfig>(m, "DecodingConfig")
        .def(py::init<std::optional<tle::DecodingMode>, std::optional<tle::LookaheadDecodingConfig>,
                 std::optional<tle::MedusaChoices>, std::optional<tle::EagleConfig>>(),
            py::arg("decoding_mode") = py::none(), py::arg("lookahead_decoding_config") = py::none(),
            py::arg("medusa_choices") = py::none(), py::arg("eagle_config") = py::none())
        .def_property("decoding_mode", &tle::DecodingConfig::getDecodingMode, &tle::DecodingConfig::setDecodingMode)
        .def_property("lookahead_decoding_config", &tle::DecodingConfig::getLookaheadDecodingConfig,
            &tle::DecodingConfig::setLookaheadDecodingConfig)
        .def_property("medusa_choices", &tle::DecodingConfig::getMedusaChoices, &tle::DecodingConfig::setMedusaChoices)
        .def_property("eagle_config", &tle::DecodingConfig::getEagleConfig, &tle::DecodingConfig::setEagleConfig)
        .def(py::pickle(decodingConfigGetstate, decodingConfigSetstate));

    auto debugConfigGetstate = [](tle::DebugConfig const& self)
    {
        return py::make_tuple(self.getDebugInputTensors(), self.getDebugOutputTensors(), self.getDebugTensorNames(),
            self.getDebugTensorsMaxIterations());
    };
    auto debugConfigSetstate = [](py::tuple const& state)
    {
        if (state.size() != 4)
        {
            throw std::runtime_error("Invalid state!");
        }
        return tle::DebugConfig(state[0].cast<bool>(), state[1].cast<bool>(), state[2].cast<std::vector<std::string>>(),
            state[3].cast<SizeType32>());
    };
    py::class_<tle::DebugConfig>(m, "DebugConfig")
        .def(py::init<bool, bool, std::vector<std::string>, SizeType32>(), py::arg("debug_input_tensors") = false,
            py::arg("debug_output_tensors") = false, py::arg("debug_tensor_names") = py::none(),
            py::arg("debug_tensors_max_iterations") = 0)
        .def_property(
            "debug_input_tensors", &tle::DebugConfig::getDebugInputTensors, &tle::DebugConfig::setDebugInputTensors)
        .def_property(
            "debug_output_tensors", &tle::DebugConfig::getDebugOutputTensors, &tle::DebugConfig::setDebugOutputTensors)
        .def_property(
            "debug_tensor_names", &tle::DebugConfig::getDebugTensorNames, &tle::DebugConfig::setDebugTensorNames)
        .def_property("debug_tensors_max_iterations", &tle::DebugConfig::getDebugTensorsMaxIterations,
            &tle::DebugConfig::setDebugTensorsMaxIterations)
        .def(py::pickle(debugConfigGetstate, debugConfigSetstate));

    auto logitsPostProcessorConfigGetstate = [](tle::LogitsPostProcessorConfig const& self)
    { return py::make_tuple(self.getProcessorMap(), self.getProcessorBatched(), self.getReplicate()); };

    auto logitsPostProcessorConfigSetstate = [](py::tuple const& state)
    {
        if (state.size() != 3)
        {
            throw std::runtime_error("Invalid LogitsPostProcessorConfig state!");
        }
        return tle::LogitsPostProcessorConfig(state[0].cast<std::optional<tle::LogitsPostProcessorMap>>(),
            state[1].cast<std::optional<tle::LogitsPostProcessorBatched>>(), state[2].cast<bool>());
    };

    py::class_<tle::LogitsPostProcessorConfig>(m, "LogitsPostProcessorConfig")
        .def(py::init<std::optional<tle::LogitsPostProcessorMap>, std::optional<tle::LogitsPostProcessorBatched>,
                 bool>(),
            py::arg("processor_map") = py::none(), py::arg("processor_batched") = py::none(),
            py::arg("replicate") = true)
        .def_property("processor_map", &tle::LogitsPostProcessorConfig::getProcessorMap,
            &tle::LogitsPostProcessorConfig::setProcessorMap)
        .def_property("processor_batched", &tle::LogitsPostProcessorConfig::getProcessorBatched,
            &tle::LogitsPostProcessorConfig::setProcessorBatched)
        .def_property(
            "replicate", &tle::LogitsPostProcessorConfig::getReplicate, &tle::LogitsPostProcessorConfig::setReplicate)
        .def(py::pickle(logitsPostProcessorConfigGetstate, logitsPostProcessorConfigSetstate));

    auto extendedRuntimePerfKnobConfigSetstate = [](py::tuple const& state)
    {
        if (state.size() != 4)
        {
            throw std::runtime_error("Invalid extendedRuntimePerfKnobConfig state!");
        }
        return tle::ExtendedRuntimePerfKnobConfig(
            state[0].cast<bool>(), state[1].cast<bool>(), state[2].cast<bool>(), state[3].cast<SizeType32>());
    };
    auto extendedRuntimePerfKnobConfigGetstate = [](tle::ExtendedRuntimePerfKnobConfig const& self)
    {
        return py::make_tuple(self.getMultiBlockMode(), self.getEnableContextFMHAFP32Acc(), self.getCudaGraphMode(),
            self.getCudaGraphCacheSize());
    };
    py::class_<tle::ExtendedRuntimePerfKnobConfig>(m, "ExtendedRuntimePerfKnobConfig")
        .def(
            py::init<bool, bool>(), py::arg("multi_block_mode") = true, py::arg("enable_context_fmha_fp32_acc") = false)
        .def_property("multi_block_mode", &tle::ExtendedRuntimePerfKnobConfig::getMultiBlockMode,
            &tle::ExtendedRuntimePerfKnobConfig::setMultiBlockMode)
        .def_property("enable_context_fmha_fp32_acc", &tle::ExtendedRuntimePerfKnobConfig::getEnableContextFMHAFP32Acc,
            &tle::ExtendedRuntimePerfKnobConfig::setEnableContextFMHAFP32Acc)
        .def_property("cuda_graph_mode", &tle::ExtendedRuntimePerfKnobConfig::getCudaGraphMode,
            &tle::ExtendedRuntimePerfKnobConfig::setCudaGraphMode)
        .def_property("cuda_graph_cache_size", &tle::ExtendedRuntimePerfKnobConfig::getCudaGraphCacheSize,
            &tle::ExtendedRuntimePerfKnobConfig::setCudaGraphCacheSize)
        .def(py::pickle(extendedRuntimePerfKnobConfigGetstate, extendedRuntimePerfKnobConfigSetstate));

    auto SpeculativeDecodingConfigGetState
        = [](tle::SpeculativeDecodingConfig const& self) { return py::make_tuple(self.fastLogits); };
    auto SpeculativeDecodingConfigSetState = [](py::tuple const& state)
    {
        if (state.size() != 1)
        {
            throw std::runtime_error("Invalid SpeculativeDecodingConfig state!");
        }
        return tle::SpeculativeDecodingConfig(state[0].cast<bool>());
    };
    py::class_<tle::SpeculativeDecodingConfig>(m, "SpeculativeDecodingConfig")
        .def(py::init<bool>(), py::arg("fast_logits") = false)
        .def_readwrite("fast_logits", &tle::SpeculativeDecodingConfig::fastLogits)
        .def(py::pickle(SpeculativeDecodingConfigGetState, SpeculativeDecodingConfigSetState));

    // Guided decoding config
    auto pyGuidedDecodingConfig = py::class_<tle::GuidedDecodingConfig>(m, "GuidedDecodingConfig");

    py::enum_<tle::GuidedDecodingConfig::GuidedDecodingBackend>(pyGuidedDecodingConfig, "GuidedDecodingBackend")
        .value("XGRAMMAR", tle::GuidedDecodingConfig::GuidedDecodingBackend::kXGRAMMAR)
        .value("LLGUIDANCE", tle::GuidedDecodingConfig::GuidedDecodingBackend::kLLGUIDANCE);

    auto guidedDecodingConfigGetstate = [](tle::GuidedDecodingConfig const& self) {
        return py::make_tuple(
            self.getBackend(), self.getEncodedVocab(), self.getTokenizerStr(), self.getStopTokenIds());
    };
    auto guidedDecodingConfigSetstate = [](py::tuple state)
    {
        if (state.size() != 4)
        {
            throw std::runtime_error("Invalid GuidedDecodingConfig state!");
        }
        return tle::GuidedDecodingConfig(state[0].cast<tle::GuidedDecodingConfig::GuidedDecodingBackend>(),
            state[1].cast<std::optional<std::vector<std::string>>>(), state[2].cast<std::optional<std::string>>(),
            state[3].cast<std::optional<std::vector<tle::TokenIdType>>>());
    };

    pyGuidedDecodingConfig
        .def(py::init<tle::GuidedDecodingConfig::GuidedDecodingBackend, std::optional<std::vector<std::string>>,
                 std::optional<std::string>, std::optional<std::vector<tle::TokenIdType>>>(),
            py::arg("backend"), py::arg("encoded_vocab") = py::none(), py::arg("tokenizer_str") = py::none(),
            py::arg("stop_token_ids") = py::none())
        .def_property("backend", &tle::GuidedDecodingConfig::getBackend, &tle::GuidedDecodingConfig::setBackend)
        .def_property(
            "encoded_vocab", &tle::GuidedDecodingConfig::getEncodedVocab, &tle::GuidedDecodingConfig::setEncodedVocab)
        .def_property(
            "tokenizer_str", &tle::GuidedDecodingConfig::getTokenizerStr, &tle::GuidedDecodingConfig::setTokenizerStr)
        .def_property(
            "stop_token_ids", &tle::GuidedDecodingConfig::getStopTokenIds, &tle::GuidedDecodingConfig::setStopTokenIds)
        .def(py::pickle(guidedDecodingConfigGetstate, guidedDecodingConfigSetstate));

    auto cacheTransceiverConfigGetstate = [](tle::CacheTransceiverConfig const& self)
    { return py::make_tuple(self.getBackendType(), self.getMaxTokensInBuffer(), self.getKvTransferTimeoutMs()); };
    auto cacheTransceiverConfigSetstate = [](py::tuple const& state)
    {
        if (state.size() != 3)
        {
            throw std::runtime_error("Invalid CacheTransceiverConfig state!");
        }
        return tle::CacheTransceiverConfig(state[0].cast<tle::CacheTransceiverConfig::BackendType>(),
            state[1].cast<std::optional<size_t>>(), state[2].cast<std::optional<int>>());
    };

    py::enum_<tle::CacheTransceiverConfig::BackendType>(m, "CacheTransceiverBackendType")
        .value("DEFAULT", tle::CacheTransceiverConfig::BackendType::DEFAULT)
        .value("MPI", tle::CacheTransceiverConfig::BackendType::MPI)
        .value("UCX", tle::CacheTransceiverConfig::BackendType::UCX)
        .value("NIXL", tle::CacheTransceiverConfig::BackendType::NIXL)
        .value("MOONCAKE", tle::CacheTransceiverConfig::BackendType::MOONCAKE)
        .def("from_string",
            [](std::string const& str)
            {
                if (str == "DEFAULT" || str == "default")
                    return tle::CacheTransceiverConfig::BackendType::DEFAULT;
                if (str == "MPI" || str == "mpi")
                    return tle::CacheTransceiverConfig::BackendType::MPI;
                if (str == "UCX" || str == "ucx")
                    return tle::CacheTransceiverConfig::BackendType::UCX;
                if (str == "NIXL" || str == "nixl")
                    return tle::CacheTransceiverConfig::BackendType::NIXL;
                if (str == "MOONCAKE" || str == "mooncake")
                    return tle::CacheTransceiverConfig::BackendType::MOONCAKE;
                throw std::runtime_error("Invalid backend type: " + str);
            });

    py::class_<tle::CacheTransceiverConfig>(m, "CacheTransceiverConfig")
        .def(py::init<std::optional<tle::CacheTransceiverConfig::BackendType>, std::optional<size_t>,
                 std::optional<int>, std::optional<int>>(),
            py::arg("backend") = std::nullopt, py::arg("max_tokens_in_buffer") = std::nullopt,
            py::arg("kv_transfer_timeout_ms") = std::nullopt,
            py::arg("kv_transfer_sender_future_timeout_ms") = std::nullopt)
        .def_property(
            "backend", &tle::CacheTransceiverConfig::getBackendType, &tle::CacheTransceiverConfig::setBackendType)
        .def_property("max_tokens_in_buffer", &tle::CacheTransceiverConfig::getMaxTokensInBuffer,
            &tle::CacheTransceiverConfig::setMaxTokensInBuffer)
        .def_property("kv_transfer_timeout_ms", &tle::CacheTransceiverConfig::getKvTransferTimeoutMs,
            &tle::CacheTransceiverConfig::setKvTransferTimeoutMs)
        .def_property("kv_transfer_sender_future_timeout_ms",
            &tle::CacheTransceiverConfig::getKvTransferSenderFutureTimeoutMs,
            &tle::CacheTransceiverConfig::setKvTransferSenderFutureTimeoutMs)
        .def(py::pickle(cacheTransceiverConfigGetstate, cacheTransceiverConfigSetstate));

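// A sketch of attaching a from_string helper to a bound enum, as done for
// CacheTransceiverBackendType above; because the lambda takes no self
// parameter, Python can call it directly off the enum class, e.g.
// Color.from_string("red"). Color/enum_sketch are hypothetical names.
#include <pybind11/pybind11.h>
#include <stdexcept>
#include <string>
namespace py = pybind11;

enum class Color
{
    kRED,
    kBLUE
};

PYBIND11_MODULE(enum_sketch, m)
{
    py::enum_<Color>(m, "Color")
        .value("RED", Color::kRED)
        .value("BLUE", Color::kBLUE)
        .def("from_string",
            [](std::string const& s)
            {
                if (s == "RED" || s == "red")
                    return Color::kRED;
                if (s == "BLUE" || s == "blue")
                    return Color::kBLUE;
                throw std::runtime_error("Invalid color: " + s);
            });
}
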
    auto executorConfigGetState = [](py::object const& self)
    {
        auto& c = self.cast<tle::ExecutorConfig&>();
        // Return a tuple containing C++ data and the Python __dict__
        auto cpp_states = py::make_tuple(c.getMaxBeamWidth(), c.getSchedulerConfig(), c.getKvCacheConfig(),
            c.getEnableChunkedContext(), c.getNormalizeLogProbs(), c.getIterStatsMaxIterations(),
            c.getRequestStatsMaxIterations(), c.getBatchingType(), c.getMaxBatchSize(), c.getMaxNumTokens(),
            c.getParallelConfig(), c.getPeftCacheConfig(), c.getLogitsPostProcessorConfig(), c.getDecodingConfig(),
            c.getUseGpuDirectStorage(), c.getGpuWeightsPercent(), c.getMaxQueueSize(),
            c.getExtendedRuntimePerfKnobConfig(), c.getDebugConfig(), c.getRecvPollPeriodMs(),
            c.getMaxSeqIdleMicroseconds(), c.getSpecDecConfig(), c.getGuidedDecodingConfig(),
            c.getAdditionalModelOutputs(), c.getCacheTransceiverConfig(), c.getGatherGenerationLogits(),
            c.getPromptTableOffloading(), c.getEnableTrtOverlap(), c.getFailFastOnAttentionWindowTooLarge());
        auto pickle_tuple = py::make_tuple(cpp_states, py::getattr(self, "__dict__"));
        return pickle_tuple;
    };
    auto executorConfigSetState = [](py::tuple const& state)
    {
        if (state.size() != 2)
        {
            throw std::runtime_error("Invalid state!");
        }

        // Restore C++ data
        auto cpp_states = state[0].cast<py::tuple>();
        if (cpp_states.size() != 29)
        {
            throw std::runtime_error("Invalid cpp_states!");
        }

        auto ec = tle::ExecutorConfig(                                                     //
            cpp_states[0].cast<SizeType32>(),                                              // MaxBeamWidth
            cpp_states[1].cast<tle::SchedulerConfig>(),                                    // SchedulerConfig
            cpp_states[2].cast<tle::KvCacheConfig>(),                                      // KvCacheConfig
            cpp_states[3].cast<bool>(),                                                    // EnableChunkedContext
            cpp_states[4].cast<bool>(),                                                    // NormalizeLogProbs
            cpp_states[5].cast<SizeType32>(),                                              // IterStatsMaxIterations
            cpp_states[6].cast<SizeType32>(),                                              // RequestStatsMaxIterations
            cpp_states[7].cast<tle::BatchingType>(),                                       // BatchingType
            cpp_states[8].cast<std::optional<SizeType32>>(),                               // MaxBatchSize
            cpp_states[9].cast<std::optional<SizeType32>>(),                               // MaxNumTokens
            cpp_states[10].cast<std::optional<tle::ParallelConfig>>(),                     // ParallelConfig
            cpp_states[11].cast<std::optional<tle::PeftCacheConfig>>(),                    // PeftCacheConfig
            cpp_states[12].cast<std::optional<tle::LogitsPostProcessorConfig>>(),          // LogitsPostProcessorConfig
            cpp_states[13].cast<std::optional<tle::DecodingConfig>>(),                     // DecodingConfig
            cpp_states[14].cast<bool>(),                                                   // UseGpuDirectStorage
            cpp_states[15].cast<float>(),                                                  // GpuWeightsPercent
            cpp_states[16].cast<std::optional<SizeType32>>(),                              // MaxQueueSize
            cpp_states[17].cast<tle::ExtendedRuntimePerfKnobConfig>(),                     // ExtendedRuntimePerfKnobConfig
            cpp_states[18].cast<std::optional<tle::DebugConfig>>(),                        // DebugConfig
            cpp_states[19].cast<SizeType32>(),                                             // RecvPollPeriodMs
            cpp_states[20].cast<uint64_t>(),                                               // MaxSeqIdleMicroseconds
            cpp_states[21].cast<std::optional<tle::SpeculativeDecodingConfig>>(),          // SpecDecConfig
            cpp_states[22].cast<std::optional<tle::GuidedDecodingConfig>>(),               // GuidedDecodingConfig
            cpp_states[23].cast<std::optional<std::vector<tle::AdditionalModelOutput>>>(), // AdditionalModelOutputs
            cpp_states[24].cast<std::optional<tle::CacheTransceiverConfig>>(),             // CacheTransceiverConfig
            cpp_states[25].cast<bool>(),                                                   // GatherGenerationLogits
            cpp_states[26].cast<bool>(),                                                   // PromptTableOffloading
            cpp_states[27].cast<bool>(),                                                   // EnableTrtOverlap
            cpp_states[28].cast<bool>() // FailFastOnAttentionWindowTooLarge
        );

        auto py_state = state[1].cast<py::dict>();

        return std::make_pair(ec, py_state);
    };

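// A sketch of the pickling scheme used for ExecutorConfig above: because the
// class is bound with py::dynamic_attr(), Python code may attach extra
// attributes, so getstate pickles the pair (C++ fields, __dict__) and
// setstate returns an (object, dict) pair that restores both halves.
// Cfg/dynattr_sketch are hypothetical names.
#include <pybind11/pybind11.h>
#include <stdexcept>
#include <utility>
namespace py = pybind11;

struct Cfg
{
    int n = 0;
};

PYBIND11_MODULE(dynattr_sketch, m)
{
    py::class_<Cfg>(m, "Cfg", py::dynamic_attr())
        .def(py::init<>())
        .def_readwrite("n", &Cfg::n)
        .def(py::pickle(
            [](py::object const& self) { return py::make_tuple(self.cast<Cfg&>().n, self.attr("__dict__")); },
            [](py::tuple const& state)
            {
                if (state.size() != 2)
                {
                    throw std::runtime_error("Invalid state!");
                }
                Cfg cfg;
                cfg.n = state[0].cast<int>();
                return std::make_pair(cfg, state[1].cast<py::dict>());
            }));
}
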
py::class_<tle::ExecutorConfig>(m, "ExecutorConfig", pybind11::dynamic_attr())
.def(py::init< //
SizeType32, // MaxBeamWidth
tle::SchedulerConfig const&, // SchedulerConfig
tle::KvCacheConfig const&, // KvCacheConfig
bool, // EnableChunkedContext
bool, // NormalizeLogProbs
SizeType32, // IterStatsMaxIterations
SizeType32, // RequestStatsMaxIterations
tle::BatchingType, // BatchingType
std::optional<SizeType32>, // MaxBatchSize
std::optional<SizeType32>, // MaxNumTokens
std::optional<tle::ParallelConfig>, // ParallelConfig
tle::PeftCacheConfig const&, // PeftCacheConfig
std::optional<tle::LogitsPostProcessorConfig>, // LogitsPostProcessorConfig
std::optional<tle::DecodingConfig>, // DecodingConfig
bool, // UseGpuDirectStorage
float, // GpuWeightsPercent
std::optional<SizeType32>, // MaxQueueSize
tle::ExtendedRuntimePerfKnobConfig const&, // ExtendedRuntimePerfKnobConfig
std::optional<tle::DebugConfig>, // DebugConfig
SizeType32, // RecvPollPeriodMs
uint64_t, // MaxSeqIdleMicroseconds
std::optional<tle::SpeculativeDecodingConfig>, // SpecDecConfig
std::optional<tle::GuidedDecodingConfig>, // GuidedDecodingConfig
std::optional<std::vector<tle::AdditionalModelOutput>>, // AdditionalModelOutputs
std::optional<tle::CacheTransceiverConfig>, // CacheTransceiverConfig
bool, // GatherGenerationLogits
bool, // PromptTableOffloading
bool, // EnableTrtOverlap
bool // FailFastOnAttentionWindowTooLarge
>(),
py::arg("max_beam_width") = 1, py::arg_v("scheduler_config", tle::SchedulerConfig(), "SchedulerConfig()"),
py::arg_v("kv_cache_config", tle::KvCacheConfig(), "KvCacheConfig()"),
py::arg("enable_chunked_context") = false, py::arg("normalize_log_probs") = true,
py::arg("iter_stats_max_iterations") = tle::ExecutorConfig::kDefaultIterStatsMaxIterations,
py::arg("request_stats_max_iterations") = tle::ExecutorConfig::kDefaultRequestStatsMaxIterations,
py::arg_v("batching_type", tle::BatchingType::kINFLIGHT, "BatchingType.INFLIGHT"),
py::arg("max_batch_size") = py::none(), py::arg("max_num_tokens") = py::none(),
py::arg("parallel_config") = py::none(),
py::arg_v("peft_cache_config", tle::PeftCacheConfig(), "PeftCacheConfig()"),
py::arg("logits_post_processor_config") = py::none(), py::arg("decoding_config") = py::none(),
py::arg("use_gpu_direct_storage") = false, py::arg("gpu_weights_percent") = 1.0,
py::arg("max_queue_size") = py::none(),
py::arg_v("extended_runtime_perf_knob_config", tle::ExtendedRuntimePerfKnobConfig(),
"ExtendedRuntimePerfKnobConfig()"),
py::arg("debug_config") = py::none(), py::arg("recv_poll_period_ms") = 0,
py::arg("max_seq_idle_microseconds") = tle::ExecutorConfig::kDefaultMaxSeqIdleMicroseconds,
py::arg("spec_dec_config") = py::none(), py::arg("guided_decoding_config") = py::none(),
py::arg("additional_model_outputs") = py::none(), py::arg("cache_transceiver_config") = py::none(),
py::arg("gather_generation_logits") = false, py::arg("mm_embedding_offloading") = false,
py::arg("enable_trt_overlap") = false, py::arg("fail_fast_on_attention_window_too_large") = false)
.def_property("max_beam_width", &tle::ExecutorConfig::getMaxBeamWidth, &tle::ExecutorConfig::setMaxBeamWidth)
.def_property("max_batch_size", &tle::ExecutorConfig::getMaxBatchSize, &tle::ExecutorConfig::setMaxBatchSize)
.def_property("max_num_tokens", &tle::ExecutorConfig::getMaxNumTokens, &tle::ExecutorConfig::setMaxNumTokens)
.def_property(
"scheduler_config", &tle::ExecutorConfig::getSchedulerConfigRef, &tle::ExecutorConfig::setSchedulerConfig)
.def_property(
"kv_cache_config", &tle::ExecutorConfig::getKvCacheConfigRef, &tle::ExecutorConfig::setKvCacheConfig)
.def_property("enable_chunked_context", &tle::ExecutorConfig::getEnableChunkedContext,
&tle::ExecutorConfig::setEnableChunkedContext)
.def_property("normalize_log_probs", &tle::ExecutorConfig::getNormalizeLogProbs,
&tle::ExecutorConfig::setNormalizeLogProbs)
.def_property("iter_stats_max_iterations", &tle::ExecutorConfig::getIterStatsMaxIterations,
&tle::ExecutorConfig::setIterStatsMaxIterations)
.def_property("request_stats_max_iterations", &tle::ExecutorConfig::getRequestStatsMaxIterations,
&tle::ExecutorConfig::setRequestStatsMaxIterations)
.def_property("batching_type", &tle::ExecutorConfig::getBatchingType, &tle::ExecutorConfig::setBatchingType)
.def_property(
"parallel_config", &tle::ExecutorConfig::getParallelConfig, &tle::ExecutorConfig::setParallelConfig)
.def_property(
"peft_cache_config", &tle::ExecutorConfig::getPeftCacheConfig, &tle::ExecutorConfig::setPeftCacheConfig)
.def_property("logits_post_processor_config", &tle::ExecutorConfig::getLogitsPostProcessorConfig,
&tle::ExecutorConfig::setLogitsPostProcessorConfig)
.def_property(
"decoding_config", &tle::ExecutorConfig::getDecodingConfig, &tle::ExecutorConfig::setDecodingConfig)
.def_property("use_gpu_direct_storage", &tle::ExecutorConfig::getUseGpuDirectStorage,
&tle::ExecutorConfig::setUseGpuDirectStorage)
.def_property("gpu_weights_percent", &tle::ExecutorConfig::getGpuWeightsPercent,
&tle::ExecutorConfig::setGpuWeightsPercent)
.def_property("max_queue_size", &tle::ExecutorConfig::getMaxQueueSize, &tle::ExecutorConfig::setMaxQueueSize)
.def_property("extended_runtime_perf_knob_config", &tle::ExecutorConfig::getExtendedRuntimePerfKnobConfig,
&tle::ExecutorConfig::setExtendedRuntimePerfKnobConfig)
.def_property("debug_config", &tle::ExecutorConfig::getDebugConfig, &tle::ExecutorConfig::setDebugConfig)
.def_property(
"recv_poll_period_ms", &tle::ExecutorConfig::getRecvPollPeriodMs, &tle::ExecutorConfig::setRecvPollPeriodMs)
.def_property("max_seq_idle_microseconds", &tle::ExecutorConfig::getMaxSeqIdleMicroseconds,
&tle::ExecutorConfig::setMaxSeqIdleMicroseconds)
.def_property("spec_dec_config", &tle::ExecutorConfig::getSpecDecConfig, &tle::ExecutorConfig::setSpecDecConfig)
.def_property("guided_decoding_config", &tle::ExecutorConfig::getGuidedDecodingConfig,
&tle::ExecutorConfig::setGuidedDecodingConfig)
.def_property("additional_model_outputs", &tle::ExecutorConfig::getAdditionalModelOutputs,
&tle::ExecutorConfig::setAdditionalModelOutputs)
.def_property("cache_transceiver_config", &tle::ExecutorConfig::getCacheTransceiverConfig,
&tle::ExecutorConfig::setCacheTransceiverConfig)
.def_property("gather_generation_logits", &tle::ExecutorConfig::getGatherGenerationLogits,
&tle::ExecutorConfig::setGatherGenerationLogits)
.def_property("mm_embedding_offloading", &tle::ExecutorConfig::getPromptTableOffloading,
&tle::ExecutorConfig::setPromptTableOffloading)
.def_property(
"enable_trt_overlap", &tle::ExecutorConfig::getEnableTrtOverlap, &tle::ExecutorConfig::setEnableTrtOverlap)
.def_property("fail_fast_on_attention_window_too_large",
&tle::ExecutorConfig::getFailFastOnAttentionWindowTooLarge,
&tle::ExecutorConfig::setFailFastOnAttentionWindowTooLarge)
.def(py::pickle(executorConfigGetState, executorConfigSetState));
}

} // namespace tensorrt_llm::pybind::executor
@ -1,29 +0,0 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include "tensorrt_llm/pybind/common/customCasters.h"
#include <pybind11/pybind11.h>

namespace tensorrt_llm::pybind::executor
{

// Register bindings for executor API.
void initConfigBindings(pybind11::module_& m);

} // namespace tensorrt_llm::pybind::executor
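// Illustrative sketch (editorial, not part of the original file): init functions declared like the
// one above are aggregated into the extension module's entry point. The submodule layout below is a
// hypothetical placeholder, not taken from this commit.
//
//     #include <pybind11/pybind11.h>
//
//     PYBIND11_MODULE(bindings, m)
//     {
//         auto executorModule = m.def_submodule("executor");
//         tensorrt_llm::pybind::executor::initConfigBindings(executorModule);
//         tensorrt_llm::pybind::executor::initRequestBindings(executorModule);
//     }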
@ -1,927 +0,0 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "request.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/executor/serializeUtils.h"
#include "tensorrt_llm/executor/tensor.h"
#include "tensorrt_llm/executor/types.h"
#include "tensorrt_llm/runtime/cudaStream.h"

#include <pybind11/cast.h>
#include <pybind11/chrono.h>
#include <pybind11/functional.h>
#include <pybind11/operators.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <sstream>

#include <optional>
#include <vector>

namespace py = pybind11;
namespace tle = tensorrt_llm::executor;
using Tensor = tle::Tensor;
using SizeType32 = tle::SizeType32;
using FloatType = tle::FloatType;
using VecTokens = tle::VecTokens;
using IdType = tle::IdType;
using VecTokenExtraIds = tle::VecTokenExtraIds;

namespace tensorrt_llm::pybind::executor
{

void initRequestBindings(pybind11::module_& m)
{
py::enum_<tle::RequestType>(m, "RequestType")
.value("REQUEST_TYPE_CONTEXT_AND_GENERATION", tle::RequestType::REQUEST_TYPE_CONTEXT_AND_GENERATION)
.value("REQUEST_TYPE_CONTEXT_ONLY", tle::RequestType::REQUEST_TYPE_CONTEXT_ONLY)
.value("REQUEST_TYPE_GENERATION_ONLY", tle::RequestType::REQUEST_TYPE_GENERATION_ONLY);

py::enum_<tle::FinishReason>(m, "FinishReason")
.value("NOT_FINISHED", tle::FinishReason::kNOT_FINISHED)
.value("END_ID", tle::FinishReason::kEND_ID)
.value("STOP_WORDS", tle::FinishReason::kSTOP_WORDS)
.value("LENGTH", tle::FinishReason::kLENGTH)
.value("TIMED_OUT", tle::FinishReason::kTIMED_OUT)
.value("CANCELLED", tle::FinishReason::kCANCELLED);

py::enum_<tle::KvCacheTransferMode>(m, "KvCacheTransferMode")
.value("DRAM", tle::KvCacheTransferMode::DRAM)
.value("GDS", tle::KvCacheTransferMode::GDS)
.value("POSIX_DEBUG_FALLBACK", tle::KvCacheTransferMode::POSIX_DEBUG_FALLBACK);

auto samplingConfigGetstate = [](tle::SamplingConfig const& self)
{
return py::make_tuple(self.getBeamWidth(), self.getTopK(), self.getTopP(), self.getTopPMin(),
self.getTopPResetIds(), self.getTopPDecay(), self.getSeed(), self.getTemperature(), self.getMinTokens(),
self.getBeamSearchDiversityRate(), self.getRepetitionPenalty(), self.getPresencePenalty(),
self.getFrequencyPenalty(), self.getPromptIgnoreLength(), self.getLengthPenalty(), self.getEarlyStopping(),
self.getNoRepeatNgramSize(), self.getNumReturnSequences(), self.getMinP(), self.getBeamWidthArray());
};
auto samplingConfigSetstate = [](py::tuple const& state)
{
if (state.size() != 20)
{
throw std::runtime_error("Invalid SamplingConfig state!");
}
return tle::SamplingConfig(state[0].cast<SizeType32>(), // BeamWidth
state[1].cast<std::optional<SizeType32>>(), // TopK
state[2].cast<std::optional<FloatType>>(), // TopP
state[3].cast<std::optional<FloatType>>(), // TopPMin
state[4].cast<std::optional<tle::TokenIdType>>(), // TopPResetIds
state[5].cast<std::optional<FloatType>>(), // TopPDecay
state[6].cast<std::optional<tle::RandomSeedType>>(), // Seed
state[7].cast<std::optional<FloatType>>(), // Temperature
state[8].cast<std::optional<SizeType32>>(), // MinTokens
state[9].cast<std::optional<FloatType>>(), // BeamSearchDiversityRate
state[10].cast<std::optional<FloatType>>(), // RepetitionPenalty
state[11].cast<std::optional<FloatType>>(), // PresencePenalty
state[12].cast<std::optional<FloatType>>(), // FrequencyPenalty
state[13].cast<std::optional<SizeType32>>(), // PromptIgnoreLength
state[14].cast<std::optional<FloatType>>(), // LengthPenalty
state[15].cast<std::optional<SizeType32>>(), // EarlyStopping
state[16].cast<std::optional<SizeType32>>(), // NoRepeatNgramSize
state[17].cast<std::optional<SizeType32>>(), // NumReturnSequences
state[18].cast<std::optional<FloatType>>(), // MinP
state[19].cast<std::optional<std::vector<SizeType32>>>() // BeamWidthArray
);
};
py::class_<tle::SamplingConfig>(m, "SamplingConfig")
.def(py::init<tle::SizeType32, // beamWidth
std::optional<tle::SizeType32> const&, // topK
std::optional<tle::FloatType> const&, // topP
std::optional<tle::FloatType> const&, // topPMin
std::optional<tle::TokenIdType> const&, // topPResetIds
std::optional<tle::FloatType> const&, // topPDecay
std::optional<tle::RandomSeedType> const&, // seed
std::optional<tle::FloatType> const&, // temperature
std::optional<tle::SizeType32> const&, // minTokens
std::optional<tle::FloatType> const&, // beamSearchDiversityRate
std::optional<tle::FloatType> const&, // repetitionPenalty
std::optional<tle::FloatType> const&, // presencePenalty
std::optional<tle::FloatType> const&, // frequencyPenalty
std::optional<tle::SizeType32> const&, // promptIgnoreLength
std::optional<tle::FloatType> const&, // lengthPenalty
std::optional<tle::SizeType32> const&, // earlyStopping
std::optional<tle::SizeType32> const&, // noRepeatNgramSize
std::optional<tle::SizeType32> const&, // numReturnSequences
std::optional<tle::FloatType> const&, // minP
std::optional<std::vector<tle::SizeType32>> const& // beamWidthArray
>(),
// clang-format off
py::arg("beam_width") = 1,
py::kw_only(),
py::arg("top_k") = py::none(),
py::arg("top_p") = py::none(),
py::arg("top_p_min") = py::none(),
py::arg("top_p_reset_ids") = py::none(),
py::arg("top_p_decay") = py::none(),
py::arg("seed") = py::none(),
py::arg("temperature") = py::none(),
py::arg("min_tokens") = py::none(),
py::arg("beam_search_diversity_rate") = py::none(),
py::arg("repetition_penalty") = py::none(),
py::arg("presence_penalty") = py::none(),
py::arg("frequency_penalty") = py::none(),
py::arg("prompt_ignore_length") = py::none(),
py::arg("length_penalty") = py::none(),
py::arg("early_stopping") = py::none(),
py::arg("no_repeat_ngram_size") = py::none(),
py::arg("num_return_sequences") = py::none(),
py::arg("min_p") = py::none(),
py::arg("beam_width_array") = py::none()) // clang-format on
.def_property("beam_width", &tle::SamplingConfig::getBeamWidth, &tle::SamplingConfig::setBeamWidth)
.def_property("top_k", &tle::SamplingConfig::getTopK, &tle::SamplingConfig::setTopK)
.def_property("top_p", &tle::SamplingConfig::getTopP, &tle::SamplingConfig::setTopP)
.def_property("top_p_min", &tle::SamplingConfig::getTopPMin, &tle::SamplingConfig::setTopPMin)
.def_property("top_p_reset_ids", &tle::SamplingConfig::getTopPResetIds, &tle::SamplingConfig::setTopPResetIds)
.def_property("top_p_decay", &tle::SamplingConfig::getTopPDecay, &tle::SamplingConfig::setTopPDecay)
.def_property("seed", &tle::SamplingConfig::getSeed, &tle::SamplingConfig::setSeed)
.def_property("temperature", &tle::SamplingConfig::getTemperature, &tle::SamplingConfig::setTemperature)
.def_property("min_tokens", &tle::SamplingConfig::getMinTokens, &tle::SamplingConfig::setMinTokens)
.def_property("beam_search_diversity_rate", &tle::SamplingConfig::getBeamSearchDiversityRate,
&tle::SamplingConfig::setBeamSearchDiversityRate)
.def_property("repetition_penalty", &tle::SamplingConfig::getRepetitionPenalty,
&tle::SamplingConfig::setRepetitionPenalty)
.def_property("presence_penalty", &tle::SamplingConfig::getPresencePenalty,
[](tle::SamplingConfig& self, std::optional<FloatType> v) { self.setPresencePenalty(v); })
.def_property(
"frequency_penalty", &tle::SamplingConfig::getFrequencyPenalty, &tle::SamplingConfig::setFrequencyPenalty)
.def_property("prompt_ignore_length", &tle::SamplingConfig::getPromptIgnoreLength,
&tle::SamplingConfig::setPromptIgnoreLength)
.def_property("length_penalty", &tle::SamplingConfig::getLengthPenalty, &tle::SamplingConfig::setLengthPenalty)
.def_property("early_stopping", &tle::SamplingConfig::getEarlyStopping, &tle::SamplingConfig::setEarlyStopping)
.def_property("no_repeat_ngram_size", &tle::SamplingConfig::getNoRepeatNgramSize,
&tle::SamplingConfig::setNoRepeatNgramSize)
.def_property("num_return_sequences", &tle::SamplingConfig::getNumReturnSequences,
&tle::SamplingConfig::setNumReturnSequences)
.def_property("min_p", &tle::SamplingConfig::getMinP, &tle::SamplingConfig::setMinP)
.def_property(
"beam_width_array", &tle::SamplingConfig::getBeamWidthArray, &tle::SamplingConfig::setBeamWidthArray)
.def(py::pickle(samplingConfigGetstate, samplingConfigSetstate));

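// Illustrative sketch (editorial, not part of the original file): every getstate/setstate pair in
// this file follows the same convention as SamplingConfig above: __getstate__ packs the getters
// into a py::tuple, and __setstate__ checks the arity and casts each slot back positionally. A
// hypothetical helper capturing that convention:
//
//     #include <pybind11/pybind11.h>
//     #include <stdexcept>
//     #include <string>
//     #include <tuple>
//     #include <utility>
//
//     namespace py = pybind11;
//
//     template <typename... Args, std::size_t... Is>
//     std::tuple<Args...> castStateImpl(py::tuple const& state, std::index_sequence<Is...>)
//     {
//         return std::make_tuple(state[Is].template cast<Args>()...);
//     }
//
//     template <typename... Args>
//     std::tuple<Args...> castState(py::tuple const& state, char const* className)
//     {
//         if (state.size() != sizeof...(Args))
//         {
//             throw std::runtime_error(std::string("Invalid ") + className + " state!");
//         }
//         return castStateImpl<Args...>(state, std::index_sequence_for<Args...>{});
//     }
//
// Usage, e.g.: auto [width, topK] = castState<SizeType32, std::optional<SizeType32>>(state, "SamplingConfig");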
auto additionalModelOutputGetstate
= [](tle::AdditionalModelOutput const& self) { return py::make_tuple(self.name, self.gatherContext); };
auto additionalModelOutputSetstate = [](py::tuple const& state)
{
if (state.size() != 2)
{
throw std::runtime_error("Invalid AdditionalModelOutput state!");
}
return tle::AdditionalModelOutput(state[0].cast<std::string>(), state[1].cast<bool>());
};
py::class_<tle::AdditionalModelOutput>(m, "AdditionalModelOutput")
.def(py::init<std::string, bool>(), py::arg("name"), py::arg("gather_context") = false)
.def_readwrite("name", &tle::AdditionalModelOutput::name)
.def_readwrite("gather_context", &tle::AdditionalModelOutput::gatherContext)
.def(py::pickle(additionalModelOutputGetstate, additionalModelOutputSetstate));

auto outputConfigGetstate = [](tle::OutputConfig const& self)
{
return py::make_tuple(self.returnLogProbs, self.returnContextLogits, self.returnGenerationLogits,
self.excludeInputFromOutput, self.returnEncoderOutput, self.returnPerfMetrics, self.additionalModelOutputs);
};
auto outputConfigSetstate = [](py::tuple const& state)
{
if (state.size() != 7)
{
throw std::runtime_error("Invalid OutputConfig state!");
}
return tle::OutputConfig(state[0].cast<bool>(), state[1].cast<bool>(), state[2].cast<bool>(),
state[3].cast<bool>(), state[4].cast<bool>(), state[5].cast<bool>(),
state[6].cast<std::optional<std::vector<tle::AdditionalModelOutput>>>());
};
py::class_<tle::OutputConfig>(m, "OutputConfig")
.def(py::init<bool, bool, bool, bool, bool, bool, std::optional<std::vector<tle::AdditionalModelOutput>>>(),
py::arg("return_log_probs") = false, py::arg("return_context_logits") = false,
py::arg("return_generation_logits") = false, py::arg("exclude_input_from_output") = false,
py::arg("return_encoder_output") = false, py::arg("return_perf_metrics") = false,
py::arg("additional_model_outputs") = py::none())
.def_readwrite("return_log_probs", &tle::OutputConfig::returnLogProbs)
.def_readwrite("return_context_logits", &tle::OutputConfig::returnContextLogits)
.def_readwrite("return_generation_logits", &tle::OutputConfig::returnGenerationLogits)
.def_readwrite("exclude_input_from_output", &tle::OutputConfig::excludeInputFromOutput)
.def_readwrite("return_encoder_output", &tle::OutputConfig::returnEncoderOutput)
.def_readwrite("return_perf_metrics", &tle::OutputConfig::returnPerfMetrics)
.def_readwrite("additional_model_outputs", &tle::OutputConfig::additionalModelOutputs)
.def(py::pickle(outputConfigGetstate, outputConfigSetstate));

auto externalDraftTokensConfigGetstate = [](tle::ExternalDraftTokensConfig const& self)
{ return py::make_tuple(self.getTokens(), self.getLogits(), self.getAcceptanceThreshold()); };
auto externalDraftTokensConfigSetstate = [](py::tuple const& state)
{
if (state.size() != 3)
{
throw std::runtime_error("Invalid ExternalDraftTokensConfig state!");
}
return tle::ExternalDraftTokensConfig(state[0].cast<VecTokens>(), state[1].cast<std::optional<Tensor>>(),
state[2].cast<std::optional<FloatType>>());
};
py::class_<tle::ExternalDraftTokensConfig>(m, "ExternalDraftTokensConfig")
.def(py::init<VecTokens, std::optional<Tensor>, std::optional<FloatType> const&, std::optional<bool>>(),
py::arg("tokens"), py::arg("logits") = py::none(), py::arg("acceptance_threshold") = py::none(),
py::arg("fast_logits") = py::none())
.def_property_readonly("tokens", &tle::ExternalDraftTokensConfig::getTokens)
.def_property_readonly("logits", &tle::ExternalDraftTokensConfig::getLogits)
.def_property_readonly("acceptance_threshold", &tle::ExternalDraftTokensConfig::getAcceptanceThreshold)
.def(py::pickle(externalDraftTokensConfigGetstate, externalDraftTokensConfigSetstate))
.def_property_readonly("fast_logits", &tle::ExternalDraftTokensConfig::getFastLogits);

auto promptTuningConfigGetstate = [](tle::PromptTuningConfig const& self)
{ return py::make_tuple(self.getEmbeddingTable(), self.getInputTokenExtraIds()); };
auto promptTuningConfigSetstate = [](py::tuple const& state)
{
if (state.size() != 2)
{
throw std::runtime_error("Invalid PromptTuningConfig state!");
}
return tle::PromptTuningConfig(state[0].cast<Tensor>(), state[1].cast<std::optional<VecTokenExtraIds>>());
};
py::class_<tle::PromptTuningConfig>(m, "PromptTuningConfig")
.def(py::init<Tensor, std::optional<VecTokenExtraIds>>(), py::arg("embedding_table"),
py::arg("input_token_extra_ids") = py::none())
.def_property_readonly("embedding_table", &tle::PromptTuningConfig::getEmbeddingTable)
.def_property_readonly("input_token_extra_ids", &tle::PromptTuningConfig::getInputTokenExtraIds)
.def(py::pickle(promptTuningConfigGetstate, promptTuningConfigSetstate));

auto loraConfigGetstate = [](tle::LoraConfig const& self)
{ return py::make_tuple(self.getTaskId(), self.getWeights(), self.getConfig()); };
auto loraConfigSetstate = [](py::tuple const& state)
{
if (state.size() != 3)
{
throw std::runtime_error("Invalid LoraConfig state!");
}
return tle::LoraConfig(
state[0].cast<IdType>(), state[1].cast<std::optional<Tensor>>(), state[2].cast<std::optional<Tensor>>());
};
py::class_<tle::LoraConfig>(m, "LoraConfig")
.def(py::init<uint64_t, std::optional<Tensor>, std::optional<Tensor>>(), py::arg("task_id"),
py::arg("weights") = py::none(), py::arg("config") = py::none())
.def_property_readonly("task_id", &tle::LoraConfig::getTaskId)
.def_property_readonly("weights", &tle::LoraConfig::getWeights)
.def_property_readonly("config", &tle::LoraConfig::getConfig)
.def(py::pickle(loraConfigGetstate, loraConfigSetstate));

auto multimodalInputGetstate = [](tle::MultimodalInput const& self)
{ return py::make_tuple(self.getMultimodalHashes(), self.getMultimodalPositions(), self.getMultimodalLengths()); };
auto multimodalInputSetstate = [](py::tuple const& state)
{
if (state.size() != 3)
{
throw std::runtime_error("Invalid MultimodalInput state!");
}
return tle::MultimodalInput(state[0].cast<std::vector<std::vector<SizeType32>>>(),
state[1].cast<std::vector<SizeType32>>(), state[2].cast<std::vector<SizeType32>>());
};
py::class_<tle::MultimodalInput>(m, "MultimodalInput")
.def(py::init<std::vector<std::vector<SizeType32>>, std::vector<SizeType32>, std::vector<SizeType32>>(),
py::arg("multimodal_hashes"), py::arg("multimodal_positions"), py::arg("multimodal_lengths"))
.def_property_readonly("multimodal_hashes", &tle::MultimodalInput::getMultimodalHashes)
.def_property_readonly("multimodal_positions", &tle::MultimodalInput::getMultimodalPositions)
.def_property_readonly("multimodal_lengths", &tle::MultimodalInput::getMultimodalLengths)
.def(py::pickle(multimodalInputGetstate, multimodalInputSetstate));

auto MropeConfigGetstate = [](tle::MropeConfig const& self)
{ return py::make_tuple(self.getMRopeRotaryCosSin(), self.getMRopePositionDeltas()); };
auto MropeConfigSetstate = [](py::tuple const& state)
{
if (state.size() != 2)
{
throw std::runtime_error("Invalid MropeConfig state!");
}
return tle::MropeConfig(state[0].cast<tle::Tensor>(), state[1].cast<SizeType32>());
};
py::class_<tle::MropeConfig>(m, "MropeConfig")
.def(py::init<Tensor, SizeType32>(), py::arg("mrope_rotary_cos_sin"), py::arg("mrope_position_deltas"))
.def_property_readonly("mrope_rotary_cos_sin", &tle::MropeConfig::getMRopeRotaryCosSin)
.def_property_readonly("mrope_position_deltas", &tle::MropeConfig::getMRopePositionDeltas)
.def(py::pickle(MropeConfigGetstate, MropeConfigSetstate));

auto lookaheadDecodingConfigGetstate = [](tle::LookaheadDecodingConfig const& self)
{ return py::make_tuple(self.getWindowSize(), self.getNgramSize(), self.getVerificationSetSize()); };
auto lookaheadDecodingConfigSetstate = [](py::tuple const& state)
{
if (state.size() != 3)
{
throw std::runtime_error("Invalid LookaheadDecodingConfig state!");
}
return tle::LookaheadDecodingConfig(
state[0].cast<SizeType32>(), state[1].cast<SizeType32>(), state[2].cast<SizeType32>());
};
py::class_<tle::LookaheadDecodingConfig>(m, "LookaheadDecodingConfig")
.def(py::init<SizeType32, SizeType32, SizeType32>(), py::arg("max_window_size"), py::arg("max_ngram_size"),
py::arg("max_verification_set_size"))
.def_property_readonly("max_window_size", &tle::LookaheadDecodingConfig::getWindowSize)
.def_property_readonly("max_ngram_size", &tle::LookaheadDecodingConfig::getNgramSize)
.def_property_readonly("max_verification_set_size", &tle::LookaheadDecodingConfig::getVerificationSetSize)
.def("calculate_speculative_resource", &tle::LookaheadDecodingConfig::calculateSpeculativeResource)
.def_static(
"calculate_speculative_resource_tuple", &tle::LookaheadDecodingConfig::calculateSpeculativeResourceTuple)
.def(py::pickle(lookaheadDecodingConfigGetstate, lookaheadDecodingConfigSetstate))
.def_static("get_default_lookahead_decoding_window",
[]() { return tle::LookaheadDecodingConfig::kDefaultLookaheadDecodingWindow; })
.def_static("get_default_lookahead_decoding_ngram",
[]() { return tle::LookaheadDecodingConfig::kDefaultLookaheadDecodingNgram; })
.def_static("get_default_lookahead_decoding_verification_set",
[]() { return tle::LookaheadDecodingConfig::kDefaultLookaheadDecodingVerificationSet; });

auto TokenRangeRetentionConfigGetstate = [](tle::KvCacheRetentionConfig::TokenRangeRetentionConfig const& self)
{ return py::make_tuple(self.tokenStart, self.tokenEnd, self.priority, self.durationMs); };
auto TokenRangeRetentionConfigSetstate = [](py::tuple const& state)
{
if (state.size() != 4)
{
throw std::runtime_error("Invalid state!");
}
return tle::KvCacheRetentionConfig::TokenRangeRetentionConfig(state[0].cast<SizeType32>(),
state[1].cast<std::optional<SizeType32>>(), state[2].cast<tle::RetentionPriority>(),
state[3].cast<std::optional<std::chrono::milliseconds>>());
};
auto kvCacheRetentionConfigGetstate = [](tle::KvCacheRetentionConfig const& self)
{
return py::make_tuple(self.getTokenRangeRetentionConfigs(), self.getDecodeRetentionPriority(),
self.getDecodeDurationMs(), self.getTransferMode(), self.getDirectory());
};
auto kvCacheRetentionConfigSetstate = [](py::tuple const& state)
{
if (state.size() != 5)
{
throw std::runtime_error("Invalid state!");
}
return tle::KvCacheRetentionConfig(
state[0].cast<std::vector<tle::KvCacheRetentionConfig::TokenRangeRetentionConfig>>(),
state[1].cast<tle::RetentionPriority>(), state[2].cast<std::optional<std::chrono::milliseconds>>(),
state[3].cast<tle::KvCacheTransferMode>(), state[4].cast<std::string>());
};

auto kvCacheRetentionConfig = py::class_<tle::KvCacheRetentionConfig>(m, "KvCacheRetentionConfig");

py::class_<tle::KvCacheRetentionConfig::TokenRangeRetentionConfig>(
kvCacheRetentionConfig, "TokenRangeRetentionConfig")
.def(py::init<SizeType32, std::optional<SizeType32>, tle::RetentionPriority,
std::optional<std::chrono::milliseconds>>(),
py::arg("token_start"), py::arg("token_end"), py::arg("priority"), py::arg("duration_ms") = py::none())
.def_readwrite("token_start", &tle::KvCacheRetentionConfig::TokenRangeRetentionConfig::tokenStart)
.def_readwrite("token_end", &tle::KvCacheRetentionConfig::TokenRangeRetentionConfig::tokenEnd)
.def_readwrite("priority", &tle::KvCacheRetentionConfig::TokenRangeRetentionConfig::priority)
.def_readwrite("duration_ms", &tle::KvCacheRetentionConfig::TokenRangeRetentionConfig::durationMs)
.def(py::pickle(TokenRangeRetentionConfigGetstate, TokenRangeRetentionConfigSetstate))
.def("__eq__", &tle::KvCacheRetentionConfig::TokenRangeRetentionConfig::operator==);

// There's a circular dependency between the declaration of the TokenRangeRetentionConfig and
// KvCacheRetentionConfig bindings. Defer definition of the KvCacheRetentionConfig bindings until the
// TokenRangeRetentionConfig bindings have been defined.
kvCacheRetentionConfig
.def(py::init<std::vector<tle::KvCacheRetentionConfig::TokenRangeRetentionConfig>, tle::RetentionPriority,
std::optional<std::chrono::milliseconds>, tle::KvCacheTransferMode, std::string>(),
py::arg("token_range_retention_configs"),
py::arg("decode_retention_priority") = tle::KvCacheRetentionConfig::kDefaultRetentionPriority,
py::arg("decode_duration_ms") = py::none(),
py::arg_v("transfer_mode", tle::KvCacheTransferMode::DRAM, "DRAM"), py::arg("directory") = py::none())
.def_property_readonly(
"token_range_retention_configs", &tle::KvCacheRetentionConfig::getTokenRangeRetentionConfigs)
.def_property_readonly("decode_retention_priority", &tle::KvCacheRetentionConfig::getDecodeRetentionPriority)
.def_property_readonly("decode_duration_ms", &tle::KvCacheRetentionConfig::getDecodeDurationMs)
.def_property_readonly("transfer_mode", &tle::KvCacheRetentionConfig::getTransferMode)
.def_property_readonly("directory", &tle::KvCacheRetentionConfig::getDirectory)
.def(py::pickle(kvCacheRetentionConfigGetstate, kvCacheRetentionConfigSetstate))
.def("__eq__", &tle::KvCacheRetentionConfig::operator==);
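// Illustrative sketch (editorial, not part of the original file): the two-step pattern above —
// create the py::class_ handle first, define its members later — is the standard pybind11 answer
// when two bindings reference each other. With hypothetical types Outer and Outer::Inner:
//
//     struct Outer { struct Inner { int v; }; std::vector<Inner> inners; };
//
//     auto pyOuter = py::class_<Outer>(m, "Outer");   // declare: the Python type object now exists
//     py::class_<Outer::Inner>(pyOuter, "Inner")      // nest Inner inside Outer's scope
//         .def(py::init<>())
//         .def_readwrite("v", &Outer::Inner::v);
//     pyOuter                                         // define Outer once Inner is bound, so
//         .def(py::init<>())                          // signatures mentioning Inner render
//         .def_readwrite("inners", &Outer::inners);   // correctly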

auto ContextPhaseParamsGetState = [](tle::ContextPhaseParams const& self)
{
if (self.getState() != nullptr)
{
auto serializedState = self.getSerializedState();
return py::make_tuple(self.getFirstGenTokens(), self.getReqId(),
py::bytes(serializedState.data(), serializedState.size()), self.getDraftTokens(), self.getCtxDpRank(),
self.getDisaggInfoEndpoint());
}
return py::make_tuple(self.getFirstGenTokens(), self.getReqId(), py::none(), self.getDraftTokens(),
self.getCtxDpRank(), self.getDisaggInfoEndpoint());
};

auto ContextPhaseParamsSetState = [](py::tuple const& state)
{
if (state.size() != 6)
{
throw std::runtime_error("Invalid ContextPhaseParams state!");
}
if (!state[2].is_none())
{
auto opaque_state = state[2].cast<py::bytes>();
auto opaque_state_str_view = std::string_view(opaque_state.cast<std::string_view>());
return std::make_unique<tle::ContextPhaseParams>(state[0].cast<VecTokens>(),
state[1].cast<tle::ContextPhaseParams::RequestIdType>(),
std::vector<char>(opaque_state_str_view.begin(), opaque_state_str_view.end()),
state[3].cast<std::optional<VecTokens>>(), state[4].cast<std::optional<SizeType32>>(),
state[5].cast<std::optional<std::string>>());
}
return std::make_unique<tle::ContextPhaseParams>(state[0].cast<VecTokens>(),
state[1].cast<tle::ContextPhaseParams::RequestIdType>(), state[3].cast<std::optional<VecTokens>>(),
state[4].cast<std::optional<SizeType32>>(), state[5].cast<std::optional<std::string>>());
};

py::class_<tle::ContextPhaseParams>(m, "ContextPhaseParams")
.def(py::init(
[](VecTokens const& first_gen_tokens, tle::ContextPhaseParams::RequestIdType req_id,
std::optional<py::bytes> const& opaque_state, std::optional<VecTokens> const& draft_tokens,
std::optional<SizeType32> const& ctx_dp_rank,
std::optional<std::string> const& disagg_info_endpoint)
{
if (opaque_state)
{
auto opaque_state_str_view = std::string_view(opaque_state.value().cast<std::string_view>());
return std::make_unique<tle::ContextPhaseParams>(first_gen_tokens, req_id,
std::vector<char>(opaque_state_str_view.begin(), opaque_state_str_view.end()),
draft_tokens, ctx_dp_rank, disagg_info_endpoint);
}
return std::make_unique<tle::ContextPhaseParams>(
first_gen_tokens, req_id, draft_tokens, ctx_dp_rank, disagg_info_endpoint);
}),
py::arg("first_gen_tokens"), py::arg("req_id"), py::arg("opaque_state") = py::none(),
py::arg("draft_tokens") = py::none(), py::arg("ctx_dp_rank") = py::none(),
py::arg("disagg_info_endpoint") = py::none())
.def_property("first_gen_tokens", &tle::ContextPhaseParams::getFirstGenTokens,
&tle::ContextPhaseParams::setFirstGenTokens)
.def_property(
"draft_tokens", &tle::ContextPhaseParams::getDraftTokens, &tle::ContextPhaseParams::setDraftTokens)
.def_property("req_id", &tle::ContextPhaseParams::getReqId, &tle::ContextPhaseParams::setReqId)
.def_property("ctx_dp_rank", &tle::ContextPhaseParams::getCtxDpRank, &tle::ContextPhaseParams::setCtxDpRank)
.def_property("disagg_info_endpoint", &tle::ContextPhaseParams::getDisaggInfoEndpoint,
&tle::ContextPhaseParams::setDisaggInfoEndpoint)
.def_property_readonly("opaque_state",
[](tle::ContextPhaseParams const& self)
{
std::optional<py::bytes> opaque_state{std::nullopt};
if (self.getState() != nullptr)
{
auto serializedState = self.getSerializedState();
opaque_state = py::bytes(serializedState.data(), serializedState.size());
}
return opaque_state;
})
.def(py::pickle(ContextPhaseParamsGetState, ContextPhaseParamsSetState));
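// Illustrative sketch (editorial, not part of the original file): the opaque_state plumbing above
// round-trips the serialized C++ state through Python bytes:
//
//     std::vector<char> blob = getSerializedStateSomehow();   // hypothetical source of bytes
//     py::bytes wire(blob.data(), blob.size());               // C++ buffer -> Python bytes
//     auto view = wire.cast<std::string_view>();              // bytes -> non-owning view
//     std::vector<char> restored(view.begin(), view.end());   // view -> owned buffer again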

auto EagleDecodingConfigGetstate = [](tle::EagleConfig const& self)
{
return py::make_tuple(self.getEagleChoices(), self.isGreedySampling(), self.getPosteriorThreshold(),
self.useDynamicTree(), self.getDynamicTreeMaxTopK());
};
auto EagleDecodingConfigSetstate = [](py::tuple const& state)
{
if (state.size() != 5)
{
throw std::runtime_error("Invalid EagleConfig state!");
}
return tle::EagleConfig(state[0].cast<std::optional<tle::EagleChoices>>(), state[1].cast<bool>(),
state[2].cast<std::optional<float>>(), state[3].cast<bool>(), state[4].cast<std::optional<SizeType32>>());
};
py::class_<tle::EagleConfig>(m, "EagleConfig")
.def(py::init<std::optional<tle::EagleChoices>, bool, std::optional<float>, bool, std::optional<SizeType32>>(),
py::arg("eagle_choices") = py::none(), py::arg("greedy_sampling") = true,
py::arg("posterior_threshold") = py::none(), py::arg("use_dynamic_tree") = false,
py::arg("dynamic_tree_max_topK") = py::none())
.def_property_readonly("eagle_choices", &tle::EagleConfig::getEagleChoices)
.def_property_readonly("greedy_sampling", &tle::EagleConfig::isGreedySampling)
.def_property_readonly("posterior_threshold", &tle::EagleConfig::getPosteriorThreshold)
.def_property_readonly("use_dynamic_tree", &tle::EagleConfig::useDynamicTree)
.def_property_readonly("dynamic_tree_max_topK", &tle::EagleConfig::getDynamicTreeMaxTopK)
.def(py::pickle(EagleDecodingConfigGetstate, EagleDecodingConfigSetstate));

// Guided decoding params
auto pyGuidedDecodingParams = py::class_<tle::GuidedDecodingParams>(m, "GuidedDecodingParams");

py::enum_<tle::GuidedDecodingParams::GuideType>(pyGuidedDecodingParams, "GuideType")
.value("JSON", tle::GuidedDecodingParams::GuideType::kJSON)
.value("JSON_SCHEMA", tle::GuidedDecodingParams::GuideType::kJSON_SCHEMA)
.value("REGEX", tle::GuidedDecodingParams::GuideType::kREGEX)
.value("EBNF_GRAMMAR", tle::GuidedDecodingParams::GuideType::kEBNF_GRAMMAR)
.value("STRUCTURAL_TAG", tle::GuidedDecodingParams::GuideType::kSTRUCTURAL_TAG);

auto guidedDecodingParamsGetstate
= [](tle::GuidedDecodingParams const& self) { return py::make_tuple(self.getGuideType(), self.getGuide()); };

auto guidedDecodingParamsSetstate = [](py::tuple state)
{
if (state.size() != 2)
{
throw std::runtime_error("Invalid GuidedDecodingParams state!");
}
return tle::GuidedDecodingParams(
state[0].cast<tle::GuidedDecodingParams::GuideType>(), state[1].cast<std::optional<std::string>>());
};

pyGuidedDecodingParams
.def(py::init<tle::GuidedDecodingParams::GuideType, std::optional<std::string>>(), py::arg("guide_type"),
py::arg("guide") = py::none())
.def_property_readonly("guide_type", &tle::GuidedDecodingParams::getGuideType)
.def_property_readonly("guide", &tle::GuidedDecodingParams::getGuide)
.def(py::pickle(guidedDecodingParamsGetstate, guidedDecodingParamsSetstate));

auto requestGetstate = [](tle::Request const& self)
{
return py::make_tuple(self.getInputTokenIds(), self.getMaxTokens(), self.getStreaming(),
self.getSamplingConfig(), self.getOutputConfig(), self.getEndId(), self.getPadId(), self.getPositionIds(),
self.getBadWords(), self.getStopWords(), self.getEmbeddingBias(), self.getExternalDraftTokensConfig(),
self.getPromptTuningConfig(), self.getMultimodalInput(), self.getMultimodalEmbedding(),
self.getMropeConfig(), self.getLoraConfig(), self.getLookaheadConfig(), self.getKvCacheRetentionConfig(),
self.getLogitsPostProcessorName(), self.getLogitsPostProcessor(), self.getEncoderInputTokenIds(),
self.getClientId(), self.getReturnAllGeneratedTokens(), self.getPriority(), self.getRequestType(),
self.getContextPhaseParams(), self.getEncoderInputFeatures(), self.getEncoderOutputLength(),
self.getCrossAttentionMask(), self.getEagleConfig(), self.getSkipCrossAttnBlocks(),
self.getGuidedDecodingParams(), self.getCacheSaltID(), self.getDisaggRequestId());
};
auto requestSetstate = [](py::tuple const& state)
{
if (state.size() != 35)
{
throw std::runtime_error("Invalid Request state!");
}
return std::make_unique<tle::Request>(state[0].cast<VecTokens>(), state[1].cast<SizeType32>(),
state[2].cast<bool>(), state[3].cast<tle::SamplingConfig>(), state[4].cast<tle::OutputConfig>(),
state[5].cast<std::optional<SizeType32>>(), state[6].cast<std::optional<SizeType32>>(),
state[7].cast<std::optional<std::vector<SizeType32>>>(),
state[8].cast<std::optional<std::list<VecTokens>>>(), state[9].cast<std::optional<std::list<VecTokens>>>(),
state[10].cast<std::optional<Tensor>>(), state[11].cast<std::optional<tle::ExternalDraftTokensConfig>>(),
state[12].cast<std::optional<tle::PromptTuningConfig>>(),
state[13].cast<std::optional<tle::MultimodalInput>>(), state[14].cast<std::optional<Tensor>>(),
state[15].cast<std::optional<tle::MropeConfig>>(), state[16].cast<std::optional<tle::LoraConfig>>(),
state[17].cast<std::optional<tle::LookaheadDecodingConfig>>(),
state[18].cast<std::optional<tle::KvCacheRetentionConfig>>(), state[19].cast<std::optional<std::string>>(),
state[20].cast<std::optional<tle::LogitsPostProcessor>>(), state[21].cast<std::optional<VecTokens>>(),
state[22].cast<std::optional<IdType>>(), state[23].cast<bool>(), state[24].cast<tle::PriorityType>(),
state[25].cast<tle::RequestType>(), state[26].cast<std::optional<tle::ContextPhaseParams>>(),
state[27].cast<std::optional<tle::Tensor>>(), state[28].cast<std::optional<SizeType32>>(),
state[29].cast<std::optional<tle::Tensor>>(), 1, state[30].cast<std::optional<tle::EagleConfig>>(),
state[31].cast<std::optional<tle::Tensor>>(), state[32].cast<std::optional<tle::GuidedDecodingParams>>(),
std::nullopt, std::nullopt, state[33].cast<std::optional<tle::CacheSaltIDType>>(),
state[34].cast<std::optional<tle::IdType>>());
};
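// Editorial note (not part of the original file): the literal 1 and the two std::nullopt arguments
// in the constructor call above stand in for numReturnSequences, languageAdapterUid and
// allottedTimeMs. requestGetstate does not serialize those three fields, so a pickled Request is
// restored with their default values rather than the originals.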

py::class_<tle::Request> request(m, "Request", pybind11::dynamic_attr());
request
.def(py::init<tle::VecTokens, // inputTokenIds
tle::SizeType32, // maxTokens
bool, // streaming
tle::SamplingConfig const&, // samplingConfig
tle::OutputConfig const&, // outputConfig
std::optional<tle::SizeType32> const&, // endId
std::optional<tle::SizeType32> const&, // padId
std::optional<std::vector<SizeType32>>, // positionIds
std::optional<std::list<tle::VecTokens>>, // badWords
std::optional<std::list<tle::VecTokens>>, // stopWords
std::optional<tle::Tensor>, // embeddingBias
std::optional<tle::ExternalDraftTokensConfig>, // externalDraftTokensConfig
std::optional<tle::PromptTuningConfig>, // pTuningConfig
std::optional<tle::MultimodalInput>, // multimodalInput
std::optional<tle::Tensor>, // multimodalEmbedding
std::optional<tle::MropeConfig>, // mRopeConfig
std::optional<tle::LoraConfig>, // loraConfig
std::optional<tle::LookaheadDecodingConfig>, // lookaheadConfig
std::optional<tle::KvCacheRetentionConfig>, // kvCacheRetentionConfig
std::optional<std::string>, // logitsPostProcessorName
std::optional<tle::LogitsPostProcessor>, // logitsPostProcessor
std::optional<tle::VecTokens>, // encoderInputTokenIds
std::optional<tle::IdType>, // clientId
bool, // returnAllGeneratedTokens
tle::PriorityType, // priority
tle::RequestType, // type
std::optional<tle::ContextPhaseParams>, // contextPhaseParams
std::optional<tle::Tensor>, // encoderInputFeatures
std::optional<tle::SizeType32>, // encoderOutputLength
std::optional<tle::Tensor>, // crossAttentionMask
SizeType32, // numReturnSequences
std::optional<tle::EagleConfig>, // eagleConfig
std::optional<tle::Tensor>, // skipCrossAttnBlocks
std::optional<tle::GuidedDecodingParams>, // guidedDecodingParams
std::optional<tle::SizeType32>, // languageAdapterUid
std::optional<tle::MillisecondsType>, // allottedTimeMs
std::optional<tle::CacheSaltIDType>, // cacheSaltID
std::optional<tle::IdType> // disaggRequestId
>(),
// clang-format off
py::arg("input_token_ids"),
py::arg("max_tokens"),
py::kw_only(),
py::arg("streaming") = false,
py::arg_v("sampling_config", tle::SamplingConfig(), "SamplingConfig()"),
py::arg_v("output_config", tle::OutputConfig(), "OutputConfig()"),
py::arg("end_id") = py::none(),
py::arg("pad_id") = py::none(),
py::arg("position_ids") = py::none(),
py::arg("bad_words") = py::none(),
py::arg("stop_words") = py::none(),
py::arg("embedding_bias") = py::none(),
py::arg("external_draft_tokens_config") = py::none(),
py::arg("prompt_tuning_config") = py::none(),
py::arg("multimodal_input") = py::none(),
py::arg("multimodal_embedding") = py::none(),
py::arg("mrope_config") = py::none(),
py::arg("lora_config") = py::none(),
py::arg("lookahead_config") = py::none(),
py::arg("kv_cache_retention_config") = py::none(),
py::arg("logits_post_processor_name") = py::none(),
py::arg("logits_post_processor") = py::none(),
py::arg("encoder_input_token_ids") = py::none(),
py::arg("client_id") = py::none(),
py::arg("return_all_generated_tokens") = false,
py::arg("priority") = tle::Request::kDefaultPriority,
py::arg_v("type", tle::RequestType::REQUEST_TYPE_CONTEXT_AND_GENERATION,
"RequestType.REQUEST_TYPE_CONTEXT_AND_GENERATION"),
py::arg("context_phase_params") = py::none(),
py::arg("encoder_input_features") = py::none(),
py::arg("encoder_output_length") = py::none(),
py::arg("cross_attention_mask") = py::none(),
py::arg("num_return_sequences") = 1,
py::arg("eagle_config") = py::none(),
py::arg("skip_cross_attn_blocks") = py::none(),
py::arg("guided_decoding_params") = py::none(),
py::arg("language_adapter_uid") = py::none(),
py::arg("allotted_time_ms") = py::none(),
py::arg("cache_salt_id") = py::none(),
py::arg("disagg_request_id") = py::none()
) // clang-format on
.def_property_readonly("input_token_ids", &tle::Request::getInputTokenIds)
.def_property_readonly("max_tokens", &tle::Request::getMaxTokens)
.def_property("streaming", &tle::Request::getStreaming, &tle::Request::setStreaming)
.def_property("sampling_config", &tle::Request::getSamplingConfig, &tle::Request::setSamplingConfig)
.def_property("output_config", &tle::Request::getOutputConfig, &tle::Request::setOutputConfig)
.def_property("end_id", &tle::Request::getEndId, &tle::Request::setEndId)
.def_property("pad_id", &tle::Request::getPadId, &tle::Request::setPadId)
.def_property("position_ids", &tle::Request::getPositionIds, &tle::Request::setPositionIds)
.def_property("bad_words", &tle::Request::getBadWords, &tle::Request::setBadWords)
.def_property("stop_words", &tle::Request::getStopWords, &tle::Request::setStopWords)
.def_property("embedding_bias", &tle::Request::getEmbeddingBias, &tle::Request::setEmbeddingBias)
.def_property("external_draft_tokens_config", &tle::Request::getExternalDraftTokensConfig,
&tle::Request::setExternalDraftTokensConfig)
.def_property(
"prompt_tuning_config", &tle::Request::getPromptTuningConfig, &tle::Request::setPromptTuningConfig)
.def_property("multimodal_input", &tle::Request::getMultimodalInput, &tle::Request::setMultimodalInput)
.def_property(
"multimodal_embedding", &tle::Request::getMultimodalEmbedding, &tle::Request::setMultimodalEmbedding)
.def_property("mrope_config", &tle::Request::getMropeConfig, &tle::Request::setMropeConfig)
.def_property("lora_config", &tle::Request::getLoraConfig, &tle::Request::setLoraConfig)
.def_property("lookahead_config", &tle::Request::getLookaheadConfig, &tle::Request::setLookaheadConfig)
.def_property("kv_cache_retention_config", &tle::Request::getKvCacheRetentionConfig,
&tle::Request::setKvCacheRetentionConfig)
.def_property("logits_post_processor_name", &tle::Request::getLogitsPostProcessorName,
&tle::Request::setLogitsPostProcessorName)
.def_property(
"logits_post_processor", &tle::Request::getLogitsPostProcessor, &tle::Request::setLogitsPostProcessor)
.def_property(
"encoder_input_token_ids", &tle::Request::getEncoderInputTokenIds, &tle::Request::setEncoderInputTokenIds)
.def_property("client_id", &tle::Request::getClientId, &tle::Request::setClientId)
.def_property("return_all_generated_tokens", &tle::Request::getReturnAllGeneratedTokens,
&tle::Request::setReturnAllGeneratedTokens)
.def_property("request_type", &tle::Request::getRequestType, &tle::Request::setRequestType)
.def_property(
"encoder_input_features", &tle::Request::getEncoderInputFeatures, &tle::Request::setEncoderInputFeatures)
.def_property(
"cross_attention_mask", &tle::Request::getCrossAttentionMask, &tle::Request::setCrossAttentionMask)
.def_property("eagle_config", &tle::Request::getEagleConfig, &tle::Request::setEagleConfig)
.def_property(
"skip_cross_attn_blocks", &tle::Request::getSkipCrossAttnBlocks, &tle::Request::setSkipCrossAttnBlocks)
.def_property(
"guided_decoding_params", &tle::Request::getGuidedDecodingParams, &tle::Request::setGuidedDecodingParams)
.def_property("allotted_time_ms", &tle::Request::getAllottedTimeMs, &tle::Request::setAllottedTimeMs)
.def_property("cache_salt_id", &tle::Request::getCacheSaltID, &tle::Request::setCacheSaltID)
.def_property(
"context_phase_params", &tle::Request::getContextPhaseParams, &tle::Request::setContextPhaseParams)
.def_property("disagg_request_id", &tle::Request::getDisaggRequestId, &tle::Request::setDisaggRequestId)
.def(py::pickle(requestGetstate, requestSetstate));
request.attr("BATCHED_POST_PROCESSOR_NAME") = tle::Request::kBatchedPostProcessorName;

py::class_<tle::SpeculativeDecodingFastLogitsInfo>(m, "SpeculativeDecodingFastLogitsInfo")
.def(py::init<>())
.def_readwrite("draft_request_id", &tle::SpeculativeDecodingFastLogitsInfo::draftRequestId)
.def_readwrite("draft_participant_id", &tle::SpeculativeDecodingFastLogitsInfo::draftParticipantId)
.def("to_tensor", &tle::SpeculativeDecodingFastLogitsInfo::toTensor);

auto requestPerfMetrics = py::class_<tle::RequestPerfMetrics>(m, "RequestPerfMetrics");

auto timingMetricsGetstate = [](tle::RequestPerfMetrics::TimingMetrics const& self)
{
return py::make_tuple(self.arrivalTime, self.firstScheduledTime, self.firstTokenTime, self.lastTokenTime,
self.kvCacheTransferStart, self.kvCacheTransferEnd, self.kvCacheSize);
};
auto timingMetricsSetstate = [](py::tuple const& state)
{
if (state.size() != 7)
{
throw std::runtime_error("Invalid TimingMetrics state!");
}
return tle::RequestPerfMetrics::TimingMetrics{state[0].cast<tle::RequestPerfMetrics::TimePoint>(),
state[1].cast<tle::RequestPerfMetrics::TimePoint>(), state[2].cast<tle::RequestPerfMetrics::TimePoint>(),
state[3].cast<tle::RequestPerfMetrics::TimePoint>(), state[4].cast<tle::RequestPerfMetrics::TimePoint>(),
state[5].cast<tle::RequestPerfMetrics::TimePoint>(), state[6].cast<size_t>()};
};
py::class_<tle::RequestPerfMetrics::TimingMetrics>(m, "TimingMetrics")
.def(py::init<>())
.def_readwrite("arrival_time", &tle::RequestPerfMetrics::TimingMetrics::arrivalTime)
.def_readwrite("first_scheduled_time", &tle::RequestPerfMetrics::TimingMetrics::firstScheduledTime)
.def_readwrite("first_token_time", &tle::RequestPerfMetrics::TimingMetrics::firstTokenTime)
.def_readwrite("last_token_time", &tle::RequestPerfMetrics::TimingMetrics::lastTokenTime)
.def_readwrite("kv_cache_transfer_start", &tle::RequestPerfMetrics::TimingMetrics::kvCacheTransferStart)
.def_readwrite("kv_cache_transfer_end", &tle::RequestPerfMetrics::TimingMetrics::kvCacheTransferEnd)
.def_readwrite("kv_cache_size", &tle::RequestPerfMetrics::TimingMetrics::kvCacheSize)
.def(py::pickle(timingMetricsGetstate, timingMetricsSetstate));

auto kvCacheMetricsGetstate = [](tle::RequestPerfMetrics::KvCacheMetrics const& self)
{
return py::make_tuple(self.numTotalAllocatedBlocks, self.numNewAllocatedBlocks, self.numReusedBlocks,
self.numMissedBlocks, self.kvCacheHitRate);
};
auto kvCacheMetricsSetstate = [](py::tuple const& state)
{
if (state.size() != 5)
{
throw std::runtime_error("Invalid KvCacheMetrics state!");
}
return tle::RequestPerfMetrics::KvCacheMetrics{state[0].cast<SizeType32>(), state[1].cast<SizeType32>(),
state[2].cast<SizeType32>(), state[3].cast<SizeType32>(), state[4].cast<float>()};
};
py::class_<tle::RequestPerfMetrics::KvCacheMetrics>(m, "KvCacheMetrics")
.def(py::init<>())
.def_readwrite("num_total_allocated_blocks", &tle::RequestPerfMetrics::KvCacheMetrics::numTotalAllocatedBlocks)
.def_readwrite("num_new_allocated_blocks", &tle::RequestPerfMetrics::KvCacheMetrics::numNewAllocatedBlocks)
.def_readwrite("num_reused_blocks", &tle::RequestPerfMetrics::KvCacheMetrics::numReusedBlocks)
.def_readwrite("num_missed_blocks", &tle::RequestPerfMetrics::KvCacheMetrics::numMissedBlocks)
.def_readwrite("kv_cache_hit_rate", &tle::RequestPerfMetrics::KvCacheMetrics::kvCacheHitRate)
.def(py::pickle(kvCacheMetricsGetstate, kvCacheMetricsSetstate));

auto speculativeDecodingMetricsGetstate = [](tle::RequestPerfMetrics::SpeculativeDecodingMetrics const& self)
{ return py::make_tuple(self.acceptanceRate, self.totalAcceptedDraftTokens, self.totalDraftTokens); };
auto speculativeDecodingMetricsSetstate = [](py::tuple const& state)
{
if (state.size() != 3)
{
throw std::runtime_error("Invalid SpeculativeDecodingMetrics state!");
}
return tle::RequestPerfMetrics::SpeculativeDecodingMetrics{
state[0].cast<float>(), state[1].cast<SizeType32>(), state[2].cast<SizeType32>()};
};

py::class_<tle::RequestPerfMetrics::SpeculativeDecodingMetrics>(m, "SpeculativeDecodingMetrics")
.def(py::init<>())
.def_readwrite("acceptance_rate", &tle::RequestPerfMetrics::SpeculativeDecodingMetrics::acceptanceRate)
.def_readwrite("total_accepted_draft_tokens",
&tle::RequestPerfMetrics::SpeculativeDecodingMetrics::totalAcceptedDraftTokens)
.def_readwrite("total_draft_tokens", &tle::RequestPerfMetrics::SpeculativeDecodingMetrics::totalDraftTokens)
.def(py::pickle(speculativeDecodingMetricsGetstate, speculativeDecodingMetricsSetstate));

auto requestPerfMetricsGetstate = [](tle::RequestPerfMetrics const& self)
{
return py::make_tuple(self.timingMetrics, self.kvCacheMetrics, self.speculativeDecoding, self.firstIter,
self.lastIter, self.iter);
};
auto requestPerfMetricsSetstate = [](py::tuple const& state)
{
if (state.size() != 6)
{
throw std::runtime_error("Invalid RequestPerfMetrics state!");
}
return tle::RequestPerfMetrics{state[0].cast<tle::RequestPerfMetrics::TimingMetrics>(),
state[1].cast<tle::RequestPerfMetrics::KvCacheMetrics>(),
state[2].cast<tle::RequestPerfMetrics::SpeculativeDecodingMetrics>(),
state[3].cast<std::optional<tle::IterationType>>(), state[4].cast<std::optional<tle::IterationType>>(),
state[5].cast<std::optional<tle::IterationType>>()};
};

// There's a circular dependency between the declaration of the TimingMetrics and RequestPerfMetrics bindings.
// Defer definition of the RequestPerfMetrics bindings until the TimingMetrics have been defined.
requestPerfMetrics.def(py::init<>())
.def_readwrite("timing_metrics", &tle::RequestPerfMetrics::timingMetrics)
.def_readwrite("kv_cache_metrics", &tle::RequestPerfMetrics::kvCacheMetrics)
.def_readwrite("speculative_decoding", &tle::RequestPerfMetrics::speculativeDecoding)
.def_readwrite("first_iter", &tle::RequestPerfMetrics::firstIter)
.def_readwrite("last_iter", &tle::RequestPerfMetrics::lastIter)
.def_readwrite("iter", &tle::RequestPerfMetrics::iter)
.def(py::pickle(requestPerfMetricsGetstate, requestPerfMetricsSetstate));

py::class_<tle::AdditionalOutput>(m, "AdditionalOutput")
.def(py::init([](std::string const& name, tle::Tensor const& output)
{ return std::make_unique<tle::AdditionalOutput>(name, output); }))
.def_readwrite("name", &tle::AdditionalOutput::name)
.def_readwrite("output", &tle::AdditionalOutput::output);

auto resultSetstate = [](py::tuple const& state)
{
if (state.size() != 14)
{
throw std::runtime_error("Invalid Request state!");
|
||||
}
|
||||
tle::Result result;
|
||||
result.isFinal = state[0].cast<bool>();
|
||||
result.outputTokenIds = state[1].cast<std::vector<VecTokens>>();
|
||||
result.cumLogProbs = state[2].cast<std::optional<std::vector<float>>>();
|
||||
result.logProbs = state[3].cast<std::optional<std::vector<std::vector<float>>>>();
|
||||
result.contextLogits = state[4].cast<std::optional<Tensor>>();
|
||||
result.generationLogits = state[5].cast<std::optional<Tensor>>();
|
||||
result.encoderOutput = state[6].cast<std::optional<Tensor>>();
|
||||
result.finishReasons = state[7].cast<std::vector<tle::FinishReason>>();
|
||||
result.sequenceIndex = state[8].cast<SizeType32>();
|
||||
result.isSequenceFinal = state[9].cast<bool>();
|
||||
result.decodingIter = state[10].cast<SizeType32>();
|
||||
result.avgDecodedTokensPerIter = state[11].cast<float>();
|
||||
result.contextPhaseParams = state[12].cast<std::optional<tle::ContextPhaseParams>>();
|
||||
result.requestPerfMetrics = state[13].cast<std::optional<tle::RequestPerfMetrics>>();
|
||||
return std::make_unique<tle::Result>(result);
|
||||
};
|
||||
|
||||
auto resultGetstate = [](tle::Result const& self)
|
||||
{
|
||||
return py::make_tuple(self.isFinal, self.outputTokenIds, self.cumLogProbs, self.logProbs, self.contextLogits,
|
||||
self.generationLogits, self.encoderOutput, self.finishReasons, self.sequenceIndex, self.isSequenceFinal,
|
||||
self.decodingIter, self.avgDecodedTokensPerIter, self.contextPhaseParams, self.requestPerfMetrics);
|
||||
};
|
||||
|
||||
py::class_<tle::Result>(m, "Result")
|
||||
.def(py::init<>())
|
||||
.def_readwrite("is_final", &tle::Result::isFinal)
|
||||
.def_readwrite("output_token_ids", &tle::Result::outputTokenIds)
|
||||
.def_readwrite("cum_log_probs", &tle::Result::cumLogProbs)
|
||||
.def_readwrite("log_probs", &tle::Result::logProbs)
|
||||
.def_readwrite("context_logits", &tle::Result::contextLogits)
|
||||
.def_readwrite("generation_logits", &tle::Result::generationLogits)
|
||||
.def_readwrite("spec_dec_fast_logits_info", &tle::Result::specDecFastLogitsInfo)
|
||||
.def_readwrite("encoder_output", &tle::Result::encoderOutput)
|
||||
.def_readwrite("finish_reasons", &tle::Result::finishReasons)
|
||||
.def_readwrite("sequence_index", &tle::Result::sequenceIndex)
|
||||
.def_readwrite("is_sequence_final", &tle::Result::isSequenceFinal)
|
||||
.def_readwrite("decoding_iter", &tle::Result::decodingIter)
|
||||
.def_readwrite("avg_decoded_tokens_per_iter", &tle::Result::avgDecodedTokensPerIter)
|
||||
.def_readwrite("context_phase_params", &tle::Result::contextPhaseParams)
|
||||
.def_readwrite("request_perf_metrics", &tle::Result::requestPerfMetrics)
|
||||
.def_readwrite("additional_outputs", &tle::Result::additionalOutputs)
|
||||
.def(py::pickle(resultGetstate, resultSetstate));
|
||||
|
||||
m.def("deserialize_result",
|
||||
[](std::string& x)
|
||||
{
|
||||
std::istringstream is(x);
|
||||
return tle::serialize_utils::deserialize<tle::Result>(is);
|
||||
});
|
||||
|
||||
auto responseGetstate = [](tle::Response const& self)
|
||||
{ return py::make_tuple(self.getRequestId(), self.getResult(), self.getClientId()); };
|
||||
|
||||
auto responseSetstate = [](py::tuple const& state)
|
||||
{
|
||||
if (state.size() != 3)
|
||||
{
|
||||
throw std::runtime_error("Invalid Request state!");
|
||||
}
|
||||
return std::make_unique<tle::Response>(
|
||||
state[0].cast<IdType>(), state[1].cast<tle::Result>(), state[2].cast<IdType>());
|
||||
};
|
||||
|
||||
py::class_<tle::Response>(m, "Response")
|
||||
.def(py::init<IdType, std::string, std::optional<IdType>>(), py::arg("request_id"), py::arg("error_msg"),
|
||||
py::arg("client_id") = std::nullopt)
|
||||
.def(py::init<IdType, tle::Result, std::optional<IdType>>(), py::arg("request_id"), py::arg("result"),
|
||||
py::arg("client_id") = std::nullopt)
|
||||
.def_property_readonly("request_id", &tle::Response::getRequestId)
|
||||
.def_property_readonly("client_id", &tle::Response::getClientId)
|
||||
.def("has_error", &tle::Response::hasError)
|
||||
.def_property_readonly("error_msg", &tle::Response::getErrorMsg)
|
||||
.def_property_readonly("result", &tle::Response::getResult)
|
||||
.def("clear_context_logits",
|
||||
[](tle::Response& self)
|
||||
{
|
||||
if (!self.hasError())
|
||||
{
|
||||
auto& result = const_cast<tle::Result&>(self.getResult());
|
||||
result.contextLogits.reset();
|
||||
}
|
||||
})
|
||||
.def("clear_generation_logits",
|
||||
[](tle::Response& self)
|
||||
{
|
||||
if (!self.hasError())
|
||||
{
|
||||
auto& result = const_cast<tle::Result&>(self.getResult());
|
||||
result.generationLogits.reset();
|
||||
}
|
||||
})
|
||||
.def(py::pickle(responseGetstate, responseSetstate));
|
||||
}
|
||||
|
||||
} // namespace tensorrt_llm::pybind::executor
|
||||
@ -1,29 +0,0 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "tensorrt_llm/pybind/common/customCasters.h"
#include <pybind11/pybind11.h>

namespace tensorrt_llm::pybind::executor
{

// Register bindings for executor API.
void initRequestBindings(pybind11::module_& m);

} // namespace tensorrt_llm::pybind::executor
@ -1,40 +0,0 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "bindings.h"

#include "tensorrt_llm/common/bindingUtils.h"
#include "tensorrt_llm/runtime/utils/pgUtils.h"

namespace tensorrt_llm::pybind::process_group
{

void initBindings(py::module_& m)
{

    m.def("init_pg",
        [](py::object world_pg_obj, py::object local_pg_obj, std::string const& pybind11_abi)
        {
            using Pg = c10d::ProcessGroup;
            using E = py::error_already_set;

            pg_utils::init_pg(common::get_intrusive_ptr<Pg, E>(world_pg_obj.ptr(), pybind11_abi),
                common::get_intrusive_ptr<Pg, E>(local_pg_obj.ptr(), pybind11_abi));
        });
}

} // namespace tensorrt_llm::pybind::process_group
@ -1,27 +0,0 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <pybind11/pybind11.h>

namespace py = pybind11;

namespace tensorrt_llm::pybind::process_group
{
void initBindings(py::module_& m);
} // namespace tensorrt_llm::pybind::process_group
@ -1,510 +0,0 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "bindings.h"
#include "hostfunc.h"
#include "moeBindings.h"
#include "tensorrt_llm/batch_manager/decoderBuffers.h"
#include "tensorrt_llm/kernels/communicationKernels/allReduceWorkspace.h"
#include "tensorrt_llm/kernels/communicationKernels/customLowPrecisionAllReduceKernels.h"
#include "tensorrt_llm/kernels/customAllReduceKernels.h"
#include "tensorrt_llm/kernels/delayStream.h"
#include "tensorrt_llm/runtime/cudaEvent.h"
#include "tensorrt_llm/runtime/cudaStream.h"
#include "tensorrt_llm/runtime/decoderState.h"
#include "tensorrt_llm/runtime/decodingInput.h"
#include "tensorrt_llm/runtime/decodingOutput.h"
#include "tensorrt_llm/runtime/gptDecoder.h"
#include "tensorrt_llm/runtime/gptDecoderBatched.h"
#include "tensorrt_llm/runtime/iBuffer.h"
#include "tensorrt_llm/runtime/iGptDecoderBatched.h"
#include "tensorrt_llm/runtime/iTensor.h"
#include "tensorrt_llm/runtime/ipcUtils.h"
#include "tensorrt_llm/runtime/lookaheadBuffers.h"
#include "tensorrt_llm/runtime/loraCache.h"
#include "tensorrt_llm/runtime/mcastGPUBuffer.h"
#include "tensorrt_llm/runtime/speculativeDecodingMode.h"
#include "tensorrt_llm/runtime/tllmRuntime.h"
#include "tensorrt_llm/runtime/torchView.h"
#include "tensorrt_llm/runtime/virtualMemory.h"

#include <ATen/ATen.h>
#include <c10/cuda/CUDAStream.h>
#include <pybind11/stl.h>
#include <pybind11/stl_bind.h>
#include <torch/extension.h>

namespace tr = tensorrt_llm::runtime;
namespace te = tensorrt_llm::executor;
namespace tb = tensorrt_llm::batch_manager;
class PyITensor : public tensorrt_llm::runtime::ITensor
{
public:
    /* Inherit the constructors */
    using ITensor::ITensor;

    [[nodiscard]] void* data() override
    {
        PYBIND11_OVERRIDE_PURE(void*, /* Return type */
            ITensor,                  /* Parent class */
            data                      /* Name of function in C++ (must match Python name) */
                                      /* Argument(s) */
        );
    }

    [[nodiscard]] void const* data() const override
    {
        PYBIND11_OVERRIDE_PURE(void const*, /* Return type */
            ITensor,                        /* Parent class */
            data                            /* Name of function in C++ (must match Python name) */
                                            /* Argument(s) */
        );
    }

    [[nodiscard]] std::size_t getSize() const override
    {
        PYBIND11_OVERRIDE_PURE(std::size_t, /* Return type */
            ITensor,                        /* Parent class */
            getSize                         /* Name of function in C++ (must match Python name) */
                                            /* Argument(s) */
        );
    }

    [[nodiscard]] std::size_t getCapacity() const override
    {
        PYBIND11_OVERRIDE_PURE(std::size_t, /* Return type */
            ITensor,                        /* Parent class */
            getCapacity                     /* Name of function in C++ (must match Python name) */
                                            /* Argument(s) */
        );
    }

    [[nodiscard]] DataType getDataType() const override
    {
        PYBIND11_OVERRIDE_PURE(DataType, /* Return type */
            ITensor,                     /* Parent class */
            getDataType                  /* Name of function in C++ (must match Python name) */
                                         /* Argument(s) */
        );
    }

    [[nodiscard]] tr::MemoryType getMemoryType() const override
    {
        PYBIND11_OVERRIDE_PURE(tr::MemoryType, /* Return type */
            ITensor,                           /* Parent class */
            getMemoryType                      /* Name of function in C++ (must match Python name) */
                                               /* Argument(s) */
        );
    }

    [[nodiscard]] char const* getMemoryTypeName() const override
    {
        PYBIND11_OVERRIDE_PURE(char const*, /* Return type */
            ITensor,                        /* Parent class */
            getMemoryTypeName               /* Name of function in C++ (must match Python name) */
                                            /* Argument(s) */
        );
    }

    void resize(std::size_t newSize) override
    {
        PYBIND11_OVERRIDE_PURE(void, /* Return type */
            ITensor,                 /* Parent class */
            resize,                  /* Name of function in C++ (must match Python name) */
            newSize                  /* Argument(s) */
        );
    }

    void release() override
    {
        PYBIND11_OVERRIDE_PURE(void, /* Return type */
            ITensor,                 /* Parent class */
            release                  /* Name of function in C++ (must match Python name) */
                                     /* Argument(s) */
        );
    }

    [[nodiscard]] Shape const& getShape() const override
    {
        PYBIND11_OVERRIDE_PURE(Shape const&, /* Return type */
            ITensor,                         /* Parent class */
            getShape                         /* Name of function in C++ (must match Python name) */
                                             /* Argument(s) */
        );
    }

    void reshape(Shape const& dims) override
    {
        PYBIND11_OVERRIDE_PURE(void, /* Return type */
            ITensor,                 /* Parent class */
            reshape,                 /* Name of function in C++ (must match Python name) */
            dims                     /* Argument(s) */
        );
    }
};
class PyIGptDecoder : public tr::IGptDecoder
{
public:
    using tr::IGptDecoder::IGptDecoder; // Inherit constructors

    void setup(tr::SamplingConfig const& samplingConfig, size_t batchSize,
        tr::DecodingInput::TensorConstPtr const& batchSlots,
        std::optional<tr::DecodingOutput> const& output = std::nullopt,
        std::optional<nvinfer1::DataType> explicitDraftTokensDType = std::nullopt,
        std::optional<std::vector<tr::ITensor::SharedConstPtr>> const& lookaheadPrompt = std::nullopt,
        std::optional<std::vector<te::LookaheadDecodingConfig>> const& lookaheadAlgoConfigs = std::nullopt) override
    {
        PYBIND11_OVERRIDE_PURE(void, IGptDecoder, setup, samplingConfig, batchSize, batchSlots, output,
            explicitDraftTokensDType, lookaheadPrompt, lookaheadAlgoConfigs);
    }

    void forwardAsync(tr::DecodingOutput& output, tr::DecodingInput const& input) override
    {
        PYBIND11_OVERRIDE_PURE(void, IGptDecoder, forwardAsync, output, input);
    }

    void forwardSync(tr::DecodingOutput& output, tr::DecodingInput const& input) override
    {
        PYBIND11_OVERRIDE_PURE(void, IGptDecoder, forwardSync, output, input);
    }

    tr::SamplingConfig const& getSamplingConfig() override
    {
        PYBIND11_OVERRIDE_PURE(tr::SamplingConfig const&, IGptDecoder, getSamplingConfig);
    }

    void disableLookahead(std::optional<tr::SamplingConfig> const& samplingConfig, tr::SizeType32 batchSize,
        tr::DecodingInput::TensorConstPtr batchSlots) override
    {
        PYBIND11_OVERRIDE_PURE(void, IGptDecoder, disableLookahead, samplingConfig, batchSize, batchSlots);
    }
};
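PyITensor and PyIGptDecoder are pybind11 "trampoline" classes: C++ subclasses whose overrides forward virtual calls to a Python implementation through PYBIND11_OVERRIDE_PURE, so Python code can subclass the abstract C++ interfaces. Reduced to its essentials, the pattern looks like the following sketch (hypothetical Animal interface, for illustration only):

#include <pybind11/pybind11.h>
#include <string>

namespace py = pybind11;

// Hypothetical abstract interface.
class Animal
{
public:
    virtual ~Animal() = default;
    virtual std::string name() const = 0;
};

// Trampoline: routes the pure virtual call back to a Python override.
class PyAnimal : public Animal
{
public:
    using Animal::Animal; // Inherit constructors

    std::string name() const override
    {
        PYBIND11_OVERRIDE_PURE(std::string, /* Return type */
            Animal,                         /* Parent class */
            name                            /* Method name, must match the Python side */
        );
    }
};

PYBIND11_MODULE(example, m)
{
    // Registering the trampoline as the second template argument lets
    // Python subclasses of Animal implement name().
    py::class_<Animal, PyAnimal>(m, "Animal").def(py::init<>()).def("name", &Animal::name);
}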
namespace tensorrt_llm::pybind::runtime
{

void initBindings(pybind11::module_& m)
{
    py::classh<tr::ITensor, PyITensor>(m, "ITensor").def(py::init());
    py::class_<tr::LoraCache::TaskLayerModuleConfig>(m, "TaskLayerModuleConfig")
        .def(py::init<>())
        .def_readwrite("page_id", &tr::LoraCache::TaskLayerModuleConfig::pageId)
        .def_readwrite("slot_idx", &tr::LoraCache::TaskLayerModuleConfig::slotIdx)
        .def_readwrite("in_size", &tr::LoraCache::TaskLayerModuleConfig::inSize)
        .def_readwrite("out_size", &tr::LoraCache::TaskLayerModuleConfig::outSize)
        .def_readwrite("module_id", &tr::LoraCache::TaskLayerModuleConfig::moduleId)
        .def_readwrite("layer_id", &tr::LoraCache::TaskLayerModuleConfig::layerId)
        .def_readwrite("adapter_size", &tr::LoraCache::TaskLayerModuleConfig::adapterSize)
        .def_readwrite("num_slots", &tr::LoraCache::TaskLayerModuleConfig::numSlots)
        .def_readwrite("weights_in_pointer", &tr::LoraCache::TaskLayerModuleConfig::weightsInPointer)
        .def_readwrite("weights_out_pointer", &tr::LoraCache::TaskLayerModuleConfig::weightsOutPointer)
        .def_readwrite("scaling_vec_pointer", &tr::LoraCache::TaskLayerModuleConfig::scalingVecPointer)
        .def(py::self == py::self);

    py::class_<tr::CudaVirtualMemoryManager>(m, "CudaVirtualMemoryManager")
        .def("release_with_tag", &tr::CudaVirtualMemoryManager::releaseWithTag, py::arg("tag"),
            py::call_guard<py::gil_scoped_release>())
        .def("materialize_with_tag", &tr::CudaVirtualMemoryManager::materializeWithTag, py::arg("tag"),
            py::call_guard<py::gil_scoped_release>());
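A convention that recurs throughout the bindings below: methods that may block in C++ are bound with py::call_guard<py::gil_scoped_release>, which drops the GIL for the duration of the C++ call so other Python threads can make progress, then reacquires it before returning to Python. The idiom in isolation (hypothetical slowCopy routine, for illustration only):

#include <pybind11/pybind11.h>
#include <chrono>
#include <thread>

namespace py = pybind11;

// Hypothetical long-running C++ routine that never touches Python objects.
void slowCopy()
{
    std::this_thread::sleep_for(std::chrono::seconds(1));
}

PYBIND11_MODULE(example, m)
{
    // The guard releases the GIL after argument conversion and reacquires
    // it before returning, so concurrent Python threads are not starved.
    m.def("slow_copy", &slowCopy, py::call_guard<py::gil_scoped_release>());
}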
    py::classh<tr::TllmRuntime>(m, "TllmRuntime")
        .def(py::init(
            [](std::filesystem::path engine_path, float gpu_weights_percent = 1.0f, bool use_shape_inference = true)
            {
                // Using default logger by passing nullptr
                return new tr::TllmRuntime(
                    tr::RawEngine(engine_path), nullptr, gpu_weights_percent, use_shape_inference);
            }))
        .def(py::init(
            [](py::buffer engine_buffer, float gpu_weights_percent = 1.0f, bool use_shape_inference = true)
            {
                py::buffer_info info = engine_buffer.request();
                if (info.ndim != 1)
                    throw std::runtime_error("Expected 1-D array for engine buffer");
                return new tr::TllmRuntime(
                    tr::RawEngine(info.ptr, info.shape[0]), nullptr, gpu_weights_percent, use_shape_inference);
            }))
        .def_property_readonly("num_contexts", &tr::TllmRuntime::getNbContexts)
        .def_property_readonly("num_profiles", &tr::TllmRuntime::getNbProfiles)
        .def("get_opt_profile_id", &tr::TllmRuntime::getOptProfileId, py::arg("num_tokens"), py::arg("split_points"),
            py::call_guard<py::gil_scoped_release>())
        .def("clear_contexts", &tr::TllmRuntime::clearContexts, py::call_guard<py::gil_scoped_release>())
        .def("execute_context", &tr::TllmRuntime::executeContext, py::arg("context_id"),
            py::call_guard<py::gil_scoped_release>())
        .def_property_readonly("stream_ptr", &tr::TllmRuntime::getStreamPtr)
        .def_property_readonly("buffer_manager",
            static_cast<tr::BufferManager& (tr::TllmRuntime::*) ()>(&tr::TllmRuntime::getBufferManager))
        .def("set_layer_profiler", &tr::TllmRuntime::setLayerProfiler, py::call_guard<py::gil_scoped_release>())
        .def("has_layer_profiler", &tr::TllmRuntime::hasLayerProfiler, py::arg("context_id"),
            py::call_guard<py::gil_scoped_release>())
        .def_property_readonly("layer_profiler_info", &tr::TllmRuntime::getLayerProfileInfo)
        .def("report_to_profiler", &tr::TllmRuntime::reportToProfiler, py::arg("context_id"),
            py::call_guard<py::gil_scoped_release>())
        .def_property_readonly("logits_dtype_from_engine",
            [](tr::TllmRuntime& self) { return self.getEngine().getTensorDataType("logits"); });
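The second TllmRuntime constructor accepts any Python object implementing the buffer protocol and inspects it through request(), which exposes the pointer, dimensionality, and shape of the underlying memory. A minimal sketch of the same idiom with a hypothetical Blob class (not part of this codebase):

#include <pybind11/pybind11.h>
#include <cstddef>
#include <stdexcept>
#include <vector>

namespace py = pybind11;

// Hypothetical type constructed from raw bytes.
class Blob
{
public:
    Blob(void const* data, std::size_t size)
        : mBytes(static_cast<char const*>(data), static_cast<char const*>(data) + size)
    {
    }

    std::size_t size() const
    {
        return mBytes.size();
    }

private:
    std::vector<char> mBytes;
};

PYBIND11_MODULE(example, m)
{
    py::class_<Blob>(m, "Blob")
        .def(py::init(
            [](py::buffer buffer)
            {
                // request() exposes pointer, shape, and itemsize of the
                // underlying memory (e.g. bytes, bytearray, numpy array).
                py::buffer_info info = buffer.request();
                if (info.ndim != 1)
                    throw std::runtime_error("Expected 1-D buffer");
                return new Blob(info.ptr, static_cast<std::size_t>(info.shape[0]) * info.itemsize);
            }))
        .def_property_readonly("size", &Blob::size);
}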
    py::class_<tr::LookaheadDecodingBuffers>(m, "LookaheadDecodingBuffers")
        .def(py::init<tr::SizeType32, tr::SizeType32, tr::BufferManager const&>(), py::arg("max_num_sequences"),
            py::arg("max_tokens_per_step"), py::arg("buffer_manager"), py::call_guard<py::gil_scoped_release>())
        .def_readwrite("generation_lengths", &tr::LookaheadDecodingBuffers::generationLengths)
        .def_readwrite("position_offsets", &tr::LookaheadDecodingBuffers::positionOffsets)
        .def_readwrite("packed_masks", &tr::LookaheadDecodingBuffers::packedMasks)
        .def_readwrite("position_ids", &tr::LookaheadDecodingBuffers::positionIds);

    py::class_<tr::ExplicitDraftTokensBuffers::Inputs>(m, "ExplicitDraftTokensBuffersInputs")
        .def("create", &tr::ExplicitDraftTokensBuffers::Inputs::create, py::arg("max_num_sequences"),
            py::arg("runtime"), py::arg("model_config"), py::arg("world_config"),
            py::call_guard<py::gil_scoped_release>())
        .def_readwrite("temperatures", &tr::ExplicitDraftTokensBuffers::Inputs::temperatures)
        .def_readwrite("position_ids_base", &tr::ExplicitDraftTokensBuffers::Inputs::positionIdsBase)
        .def_readwrite("generation_lengths", &tr::ExplicitDraftTokensBuffers::Inputs::generationLengths)
        .def_readwrite("random_data_sample", &tr::ExplicitDraftTokensBuffers::Inputs::randomDataSample)
        .def_readwrite("random_data_validation", &tr::ExplicitDraftTokensBuffers::Inputs::randomDataValidation)
        .def_readwrite("draft_tokens", &tr::ExplicitDraftTokensBuffers::Inputs::draftTokens)
        .def_readwrite("draft_indices", &tr::ExplicitDraftTokensBuffers::Inputs::draftIndices)
        .def_readwrite("draft_probs", &tr::ExplicitDraftTokensBuffers::Inputs::draftProbs)
        .def_readwrite("packed_masks", &tr::ExplicitDraftTokensBuffers::Inputs::packedMasks)
        .def_readwrite("position_ids", &tr::ExplicitDraftTokensBuffers::Inputs::positionIds)
        .def_readwrite("max_gen_length_host", &tr::ExplicitDraftTokensBuffers::Inputs::maxGenLengthHost)
        .def_readwrite("generation_lengths_host", &tr::ExplicitDraftTokensBuffers::Inputs::generationLengthsHost);

    py::class_<tr::DecodingInput>(m, "DecodingInput");
    py::class_<tr::DecodingOutput>(m, "DecodingOutput");
    py::class_<tr::CudaEvent>(m, "CudaEvent")
        .def(py::init<unsigned int>(), py::arg("flags") = cudaEventDisableTiming,
            py::call_guard<py::gil_scoped_release>())
        .def("synchronize", &tr::CudaEvent::synchronize, py::call_guard<py::gil_scoped_release>());

    py::class_<tr::IGptDecoder, PyIGptDecoder>(m, "IGptDecoder")
        .def(
            "setup",
            [](tr::IGptDecoder& self, tr::SamplingConfig const& samplingConfig, size_t batchSize,
                at::Tensor const& batchSlots, std::optional<tr::DecodingOutput> const& output = std::nullopt,
                std::optional<nvinfer1::DataType> explicitDraftTokensDType = std::nullopt,
                std::optional<std::vector<tr::ITensor::SharedConstPtr>> const& lookaheadPrompt = std::nullopt,
                std::optional<std::vector<te::LookaheadDecodingConfig>> const& lookaheadAlgoConfigs = std::nullopt)
            {
                auto tensorPtrBatchSlots = tr::TorchView::of(batchSlots);
                self.setup(samplingConfig, batchSize, std::move(tensorPtrBatchSlots), output, explicitDraftTokensDType,
                    lookaheadPrompt, lookaheadAlgoConfigs);
            },
            py::arg("sampling_config"), py::arg("batch_size"), py::arg("batch_slots"), py::arg("output") = std::nullopt,
            py::arg("explicit_draft_tokens_d_type") = std::nullopt, py::arg("lookahead_prompt") = std::nullopt,
            py::arg("lookahead_algo_configs") = std::nullopt, py::call_guard<py::gil_scoped_release>());
    py::class_<tr::decoder::DecoderState>(m, "DecoderState")
        .def(py::init<>(), py::call_guard<py::gil_scoped_release>())
        .def("setup", &tr::decoder::DecoderState::setup, py::arg("max_num_sequences"), py::arg("max_beam_width"),
            py::arg("max_attention_window"), py::arg("sink_token_length"), py::arg("max_sequence_length"),
            py::arg("dtype"), py::arg("model_config"), py::arg("world_config"), py::arg("buffer_manager"),
            py::call_guard<py::gil_scoped_release>())
        .def("setup_cache_indirection", &tr::decoder::DecoderState::setupCacheIndirection, py::arg("max_num_sequences"),
            py::arg("max_beam_width"), py::arg("max_attention_window"), py::arg("buffer_manager"),
            py::call_guard<py::gil_scoped_release>())
        .def("setup_speculative_decoding", &tr::decoder::DecoderState::setupSpeculativeDecoding,
            py::arg("speculative_decoding_mode"), py::arg("max_tokens_per_engine_step"), py::arg("dtype"),
            py::arg("model_config"), py::arg("world_config"), py::arg("buffer_manager"),
            py::call_guard<py::gil_scoped_release>())
        .def_property_readonly("joint_decoding_input", &tr::decoder::DecoderState::getJointDecodingInput)
        .def_property_readonly("joint_decoding_output", &tr::decoder::DecoderState::getJointDecodingOutput)
        .def_property_readonly("cache_indirection_input", &tr::decoder::DecoderState::getCacheIndirectionInput)
        .def_property_readonly("cache_indirection_output", &tr::decoder::DecoderState::getCacheIndirectionOutput)
        .def_property_readonly(
            "sequence_lengths", py::overload_cast<>(&tr::decoder::DecoderState::getSequenceLengths, py::const_))
        .def("get_sequence_lengths",
            py::overload_cast<tr::SizeType32>(&tr::decoder::DecoderState::getSequenceLengths, py::const_),
            py::arg("batch_idx"), py::call_guard<py::gil_scoped_release>())
        .def_property_readonly("all_new_tokens", &tr::decoder::DecoderState::getAllNewTokens)
        .def_property_readonly("finished_sum", &tr::decoder::DecoderState::getFinishedSum)
        .def_property_readonly("finish_reasons", &tr::decoder::DecoderState::getFinishReasons)
        .def_property_readonly("ids", py::overload_cast<>(&tr::decoder::DecoderState::getIds, py::const_))
        .def("get_ids", py::overload_cast<tr::SizeType32>(&tr::decoder::DecoderState::getIds, py::const_),
            py::arg("batch_idx"), py::call_guard<py::gil_scoped_release>())
        .def_property_readonly(
            "gathered_ids", py::overload_cast<>(&tr::decoder::DecoderState::getGatheredIds, py::const_))
        .def("get_gathered_ids",
            py::overload_cast<tr::SizeType32>(&tr::decoder::DecoderState::getGatheredIds, py::const_),
            py::arg("batch_idx"), py::call_guard<py::gil_scoped_release>())
        .def_property_readonly("parent_ids", &tr::decoder::DecoderState::getParentIds)
        .def_property_readonly(
            "cum_log_probs", py::overload_cast<>(&tr::decoder::DecoderState::getCumLogProbs, py::const_))
        .def("get_cum_log_probs",
            py::overload_cast<tr::SizeType32>(&tr::decoder::DecoderState::getCumLogProbs, py::const_),
            py::arg("batch_idx"), py::call_guard<py::gil_scoped_release>())
        .def_property_readonly("log_probs", py::overload_cast<>(&tr::decoder::DecoderState::getLogProbs, py::const_))
        .def("get_log_probs", py::overload_cast<tr::SizeType32>(&tr::decoder::DecoderState::getLogProbs, py::const_),
            py::arg("batch_idx"), py::call_guard<py::gil_scoped_release>())
        .def_property_readonly("next_draft_tokens", &tr::decoder::DecoderState::getNextDraftTokens)
        .def_property_readonly("prev_draft_tokens_lengths", &tr::decoder::DecoderState::getPrevDraftTokensLengths)
        .def_property_readonly("next_draft_tokens_lengths", &tr::decoder::DecoderState::getNextDraftTokensLengths)
        .def_property_readonly("accepted_lengths_cum_sum", &tr::decoder::DecoderState::getAcceptedLengthsCumSum)
        .def_property_readonly("accepted_packed_paths", &tr::decoder::DecoderState::getAcceptedPackedPaths)
        .def_property_readonly("max_beam_width", &tr::decoder::DecoderState::getMaxBeamWidth)
        .def_property_readonly("max_sequence_length", &tr::decoder::DecoderState::getMaxSequenceLength)
        .def_property_readonly("max_decoding_decoder_tokens", &tr::decoder::DecoderState::getMaxDecodingDecoderTokens)
        .def_property_readonly("max_decoding_engine_tokens", &tr::decoder::DecoderState::getMaxDecodingEngineTokens)
        .def_property_readonly("num_decoding_engine_tokens",
            py::overload_cast<>(&tr::decoder::DecoderState::getNumDecodingEngineTokens, py::const_))
        .def("get_num_decoding_engine_tokens",
            py::overload_cast<tr::SizeType32>(&tr::decoder::DecoderState::getNumDecodingEngineTokens, py::const_),
            py::arg("batch_idx"), py::call_guard<py::gil_scoped_release>())
        .def("set_num_decoding_engine_tokens", &tr::decoder::DecoderState::setNumDecodingEngineTokens,
            py::arg("batch_idx"), py::arg("num_tokens"), py::call_guard<py::gil_scoped_release>())
        .def_property_readonly("speculative_decoding_mode", &tr::decoder::DecoderState::getSpeculativeDecodingMode)
        .def_property("generation_steps", &tr::decoder::DecoderState::getGenerationSteps,
            &tr::decoder::DecoderState::setGenerationSteps);
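DecoderState repeatedly binds pairs of const overloads, a zero-argument getter exposed as a read-only property plus a per-batch-index getter exposed as a method, disambiguated with py::overload_cast and the py::const_ tag. The mechanics in isolation, with a hypothetical Store class for illustration:

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <vector>

namespace py = pybind11;

// Hypothetical class with two const overloads of the same getter.
class Store
{
public:
    std::vector<int> const& values() const
    {
        return mValues;
    }

    int values(int idx) const
    {
        return mValues.at(idx);
    }

private:
    std::vector<int> mValues{1, 2, 3};
};

PYBIND11_MODULE(example, m)
{
    py::class_<Store>(m, "Store")
        .def(py::init<>())
        // py::overload_cast selects one overload; py::const_ marks it const.
        .def_property_readonly("values", py::overload_cast<>(&Store::values, py::const_))
        .def("get_value", py::overload_cast<int>(&Store::values, py::const_), py::arg("idx"));
}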
    py::class_<tr::GptDecoderBatched>(m, "GptDecoderBatched")
        .def(py::init<tr::GptDecoderBatched::CudaStreamPtr>(), py::arg("stream"),
            py::call_guard<py::gil_scoped_release>())
        .def("setup", &tr::GptDecoderBatched::setup, py::arg("mode"), py::arg("max_num_sequences"),
            py::arg("max_beam_width"), py::arg("dtype"), py::arg("model_config"), py::arg("world_config"),
            py::call_guard<py::gil_scoped_release>())
        .def("forward_async", &tr::GptDecoderBatched::forwardAsync, py::arg("decoder_state"), py::arg("input"),
            py::call_guard<py::gil_scoped_release>())
        .def("underlying_decoder", &tr::GptDecoderBatched::getUnderlyingDecoder, py::return_value_policy::reference)
        .def("finalize", &tr::GptDecoderBatched::finalize, py::arg("decoder_state"), py::arg("batch_idx"),
            py::arg("sampling_config"), py::arg("streaming"), py::call_guard<py::gil_scoped_release>())
        .def_property_readonly(
            "decoder_stream",
            [](tr::GptDecoderBatched& self) -> tr::CudaStream const& { return *self.getDecoderStream(); },
            py::return_value_policy::reference);
    m.def(
        "lamport_initialize_all",
        [](intptr_t buffer_0, intptr_t buffer_1, intptr_t buffer_2, size_t size)
        {
            tr::lamportInitializeAll(reinterpret_cast<void*>(buffer_0), reinterpret_cast<void*>(buffer_1),
                reinterpret_cast<void*>(buffer_2), size);
        },
        "Lamport initialize all buffers", py::call_guard<py::gil_scoped_release>());
    m.def(
        "lamport_initialize",
        [](intptr_t buffer, size_t size)
        { tensorrt_llm::kernels::ar_fusion::lamport_initialize(reinterpret_cast<void*>(buffer), size, 0); },
        "Lamport initialize buffer", py::call_guard<py::gil_scoped_release>());
    m.def(
        "delay_kernel",
        [](int64_t delay_micro_secs, py::object py_stream)
        {
            // Get the raw stream handle from the PyTorch stream object
            auto stream_ptr = py_stream.attr("cuda_stream").cast<int64_t>();
            cudaStream_t stream = reinterpret_cast<cudaStream_t>(stream_ptr);
            py::gil_scoped_release release;
            tensorrt_llm::kernels::invokeDelayStreamKernel(delay_micro_secs, stream);
        },
        "Launch a delay kernel on the given stream");
    m.def(
        "max_workspace_size_lowprecision",
        [](int32_t tp_size) { return tensorrt_llm::kernels::max_workspace_size_lowprecision(tp_size); },
        "Calculate the maximum workspace size needed for low precision all-reduce operations",
        py::call_guard<py::gil_scoped_release>());
    py::enum_<tr::CudaVirtualMemoryAllocator::RestoreMode>(m, "CudaVirtualMemoryAllocatorRestoreMode")
        .value("NONE", tr::CudaVirtualMemoryAllocator::RestoreMode::NONE)
        .value("CPU", tr::CudaVirtualMemoryAllocator::RestoreMode::CPU)
        .value("PINNED", tr::CudaVirtualMemoryAllocator::RestoreMode::PINNED)
        .value("MEMSET", tr::CudaVirtualMemoryAllocator::RestoreMode::MEMSET);

    m.def("get_virtual_memory_manager", &tr::getVirtualMemoryManager, "Get the virtual memory manager",
        py::return_value_policy::reference);
    m.def(
        "set_virtual_memory_allocator",
        [](std::string const& tag, tr::CudaVirtualMemoryAllocator::RestoreMode mode, uintptr_t stream)
        {
            static_assert(sizeof(uintptr_t) == sizeof(cudaStream_t));
            tr::setVirtualMemoryAllocator(tag, mode,
                std::make_shared<tr::CudaStream>(
                    reinterpret_cast<cudaStream_t>(stream), tensorrt_llm::common::getDevice(), false));
        },
        "Set the virtual memory allocator and start allocating virtual memory for CUDA allocations",
        py::call_guard<py::gil_scoped_release>());

    m.def("clear_virtual_memory_allocator", &tr::clearVirtualMemoryAllocator,
        "Reset the current virtual memory allocator and stop allocating virtual memory for CUDA allocations",
        py::call_guard<py::gil_scoped_release>());
    py::class_<tensorrt_llm::runtime::McastGPUBuffer>(m, "McastGPUBuffer")
        .def(py::init<size_t, uint32_t, uint32_t, uint32_t, bool, int64_t>(), py::arg("buf_size"),
            py::arg("group_size"), py::arg("group_rank"), py::arg("device_idx"), py::arg("mn_nvlink"),
            py::arg("mpi_comm_fortran_handle"), py::call_guard<py::gil_scoped_release>())
        .def("get_uc_buffer", &tensorrt_llm::runtime::McastGPUBuffer::getUCBuffer,
            py::call_guard<py::gil_scoped_release>())
        .def("get_mc_buffer", &tensorrt_llm::runtime::McastGPUBuffer::getMCBuffer,
            py::call_guard<py::gil_scoped_release>());

    py::enum_<tensorrt_llm::kernels::AllReduceFusionOp>(m, "AllReduceFusionOp")
        .value("NONE", tensorrt_llm::kernels::AllReduceFusionOp::NONE)
        .value("RESIDUAL_RMS_NORM", tensorrt_llm::kernels::AllReduceFusionOp::RESIDUAL_RMS_NORM)
        .value("LAST_PROCESS_FOR_UB", tensorrt_llm::kernels::AllReduceFusionOp::LAST_PROCESS_FOR_UB)
        .value("RESIDUAL_RMS_PREPOST_NORM", tensorrt_llm::kernels::AllReduceFusionOp::RESIDUAL_RMS_PREPOST_NORM)
        .value("RESIDUAL_RMS_NORM_QUANT_FP8", tensorrt_llm::kernels::AllReduceFusionOp::RESIDUAL_RMS_NORM_QUANT_FP8)
        .value("RESIDUAL_RMS_NORM_QUANT_NVFP4", tensorrt_llm::kernels::AllReduceFusionOp::RESIDUAL_RMS_NORM_QUANT_NVFP4)
        .value("RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4",
            tensorrt_llm::kernels::AllReduceFusionOp::RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4)
        .value("RESIDUAL_RMS_NORM_OUT_QUANT_FP8",
            tensorrt_llm::kernels::AllReduceFusionOp::RESIDUAL_RMS_NORM_OUT_QUANT_FP8);

    py::enum_<tensorrt_llm::kernels::AllReduceStrategyType>(m, "AllReduceStrategy")
        .value("NCCL", tensorrt_llm::kernels::AllReduceStrategyType::NCCL)
        .value("MIN_LATENCY", tensorrt_llm::kernels::AllReduceStrategyType::MIN_LATENCY)
        .value("AUTO", tensorrt_llm::kernels::AllReduceStrategyType::AUTO)
        .value("UB", tensorrt_llm::kernels::AllReduceStrategyType::UB)
        .value("ONESHOT", tensorrt_llm::kernels::AllReduceStrategyType::ONESHOT)
        .value("TWOSHOT", tensorrt_llm::kernels::AllReduceStrategyType::TWOSHOT)
        .value("NCCL_SYMMETRIC", tensorrt_llm::kernels::AllReduceStrategyType::NCCL_SYMMETRIC);

    // Initialize MoeLoadBalancer bindings
    initMoeBindings(m);
    // Initialize HostFunc bindings
    initHostFuncBindings(m);
}

void initBindingsEarly(py::module_& m)
{
    py::classh<tr::BufferManager>(m, "BufferManager")
        .def(py::init<tr::BufferManager::CudaStreamPtr, bool>(), py::arg("stream"), py::arg("trim_pool") = false,
            py::call_guard<py::gil_scoped_release>())
        .def_property_readonly("stream", &tr::BufferManager::getStream);

    py::class_<tr::SpeculativeDecodingMode>(m, "SpeculativeDecodingMode")
        .def(py::init<tr::SpeculativeDecodingMode::UnderlyingType>(), py::arg("state"))
        .def_static("NoneType", &tr::SpeculativeDecodingMode::None) // "None" is a reserved name in Python
        .def_static("DraftTokensExternal", &tr::SpeculativeDecodingMode::DraftTokensExternal)
        .def_static("Medusa", &tr::SpeculativeDecodingMode::Medusa)
        .def_static("Eagle", &tr::SpeculativeDecodingMode::Eagle)
        .def_static("LookaheadDecoding", &tr::SpeculativeDecodingMode::LookaheadDecoding)
        .def_static("ExplicitDraftTokens", &tr::SpeculativeDecodingMode::ExplicitDraftTokens)
        .def_property_readonly("is_none", &tr::SpeculativeDecodingMode::isNone)
        .def_property_readonly("is_draft_tokens_external", &tr::SpeculativeDecodingMode::isDraftTokensExternal)
        .def_property_readonly("is_medusa", &tr::SpeculativeDecodingMode::isMedusa)
        .def_property_readonly("is_eagle", &tr::SpeculativeDecodingMode::isEagle)
        .def_property_readonly("is_lookahead_decoding", &tr::SpeculativeDecodingMode::isLookaheadDecoding)
        .def_property_readonly("is_explicit_draft_tokens", &tr::SpeculativeDecodingMode::isExplicitDraftTokens)
        .def_property_readonly("updates_position_ids", &tr::SpeculativeDecodingMode::updatesPositionIds)
        .def_property_readonly("requires_attention_mask", &tr::SpeculativeDecodingMode::requiresAttentionMask)
        .def_property_readonly("predicts_draft_tokens", &tr::SpeculativeDecodingMode::predictsDraftTokens)
        .def_property_readonly("needs_kv_cache_rewind", &tr::SpeculativeDecodingMode::needsKVCacheRewind)
        .def_property_readonly("variable_draft_length", &tr::SpeculativeDecodingMode::variableDraftLength)
        .def_property_readonly("has_draft_logits", &tr::SpeculativeDecodingMode::hasDraftLogits)
        .def_property_readonly("needs_decoder_prologue", &tr::SpeculativeDecodingMode::needsDecoderPrologue);
}
} // namespace tensorrt_llm::pybind::runtime
@ -1,31 +0,0 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "tensorrt_llm/pybind/common/customCasters.h"
#include <pybind11/pybind11.h>

namespace py = pybind11;

namespace tensorrt_llm::pybind::runtime
{

void initBindings(py::module_& m);
void initBindingsEarly(py::module_& m);

} // namespace tensorrt_llm::pybind::runtime
@ -1,120 +0,0 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "hostfunc.h"

#include "tensorrt_llm/common/logger.h"

#include <cuda_runtime.h>
#include <memory>
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

namespace py = pybind11;

namespace tensorrt_llm::pybind::runtime
{

struct HostFuncUserData
{
    bool freeUserData;
    py::function pyHostFunc;
    py::tuple pyArgs;
    py::dict pyKwargs;

    HostFuncUserData(bool freeUserData, py::function func, py::tuple args, py::dict kwargs)
        : freeUserData(freeUserData)
        , pyHostFunc(std::move(func))
        , pyArgs(std::move(args))
        , pyKwargs(std::move(kwargs))
    {
    }
};
static void cudaHostFuncTrampoline(void* userData)
{
    // Acquire the GIL since we are calling Python code from a CUDA stream.
    py::gil_scoped_acquire gil;

    auto hostFuncUserData = std::unique_ptr<HostFuncUserData>(static_cast<HostFuncUserData*>(userData));
    try
    {
        hostFuncUserData->pyHostFunc(*hostFuncUserData->pyArgs, **hostFuncUserData->pyKwargs);
    }
    catch (py::error_already_set& e)
    {
        e.restore();
        PyErr_Print();
    }
    if (hostFuncUserData->freeUserData)
    {
        // If freeUserData is true, keep the ownership of the user data.
        TLLM_LOG_DEBUG("Automatically freeing hostfunc user data %p", hostFuncUserData.get());
    }
    else
    {
        // If freeUserData is false, release the ownership of the user data.
        hostFuncUserData.release();
    }
}

std::optional<uintptr_t> launchHostFunc(
    uintptr_t streamPtr, bool freeUserData, py::function pyHostFunc, py::args pyArgs, py::kwargs pyKwargs)
{
    auto const stream = reinterpret_cast<cudaStream_t>(streamPtr);

    py::gil_scoped_acquire gil;

    auto hostFuncUserData
        = std::make_unique<HostFuncUserData>(freeUserData, pyHostFunc, py::tuple(pyArgs), py::dict(pyKwargs));

    py::gil_scoped_release release;

    cudaError_t err = cudaLaunchHostFunc(stream, cudaHostFuncTrampoline, hostFuncUserData.get());
    if (err != cudaSuccess)
    {
        throw std::runtime_error("Failed to launch host function.");
    }

    // Release the ownership of the user data.
    // If freeUserData is true, the user data will be freed by cudaHostFuncTrampoline.
    // If freeUserData is false, the user data should be freed by freeHostFuncUserData.
    auto userDataPtr = reinterpret_cast<uintptr_t>(hostFuncUserData.release());
    return freeUserData ? std::nullopt : std::make_optional(userDataPtr);
}
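launchHostFunc wraps cudaLaunchHostFunc, whose callback runs on a CUDA-internal thread once earlier work on the stream completes; the user data must therefore outlive the call, which is why ownership is released before enqueueing and reclaimed inside the trampoline. Stripped of the Python layer, the underlying CUDA ownership pattern is roughly the following standalone sketch (hypothetical Payload struct, illustration only):

#include <cuda_runtime.h>
#include <cstdio>
#include <memory>

// Hypothetical payload handed to the host callback.
struct Payload
{
    int value;
};

// Runs on a CUDA-internal thread once prior stream work has finished.
static void hostCallback(void* userData)
{
    // Take ownership back so the payload is freed exactly once.
    std::unique_ptr<Payload> payload(static_cast<Payload*>(userData));
    std::printf("callback saw value %d\n", payload->value);
}

int main()
{
    cudaStream_t stream;
    cudaStreamCreate(&stream);

    auto payload = std::make_unique<Payload>(Payload{42});
    // Release ownership: the callback frees the payload, mirroring the
    // freeUserData == true path in the bindings above.
    cudaLaunchHostFunc(stream, hostCallback, payload.release());

    cudaStreamSynchronize(stream);
    cudaStreamDestroy(stream);
    return 0;
}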
void freeHostFuncUserData(uintptr_t userDataPtr)
{
    // Acquire the GIL to safely release the Python objects.
    py::gil_scoped_acquire gil;

    // Create a unique_ptr to take over the ownership of the user data;
    // the user data is released when the unique_ptr is destroyed.
    auto hostFuncUserData = std::unique_ptr<HostFuncUserData>(reinterpret_cast<HostFuncUserData*>(userDataPtr));

    TLLM_LOG_DEBUG("Manually freeing hostfunc user data %p", hostFuncUserData.get());
}

void initHostFuncBindings(pybind11::module_& m)
{
    m.def("launch_hostfunc", &launchHostFunc, "Launch a Python host function to a CUDA stream",
        py::call_guard<py::gil_scoped_release>());
    m.def("free_hostfunc_user_data", &freeHostFuncUserData, "Free the user data for the Python host function",
        py::call_guard<py::gil_scoped_release>());
}
} // namespace tensorrt_llm::pybind::runtime
@ -1,27 +0,0 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <pybind11/pybind11.h>

namespace tensorrt_llm::pybind::runtime
{

void initHostFuncBindings(pybind11::module_& m);

} // namespace tensorrt_llm::pybind::runtime
@ -1,138 +0,0 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "moeBindings.h"
#include "tensorrt_llm/runtime/moeLoadBalancer/hostAccessibleDeviceAllocator.h"
#include "tensorrt_llm/runtime/moeLoadBalancer/moeLoadBalancer.h"
#include <pybind11/functional.h>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/stl_bind.h>
#include <vector>

namespace py = pybind11;
namespace tr = tensorrt_llm::runtime;
namespace tk = tensorrt_llm::kernels;

namespace tensorrt_llm::pybind::runtime
{

void pyDoReplication(tk::MoeLoadBalanceMetaInfo const& metaInfo, std::vector<float>& expertLoadFactor,
    tr::MoePlacementCpuInfo* cpuPlacement)
{
    TLLM_CHECK_WITH_INFO(
        metaInfo.expertCount == expertLoadFactor.size(), "expert_count and expert_load_factor size mismatch");
    tr::doReplication(metaInfo, expertLoadFactor.data(), cpuPlacement);
}

void pyDoPlacement(tk::MoeLoadBalanceMetaInfo const& metaInfo, std::vector<float>& expertLoadFactor,
    tr::MoePlacementCpuInfo* cpuPlacement)
{
    TLLM_CHECK_WITH_INFO(
        metaInfo.expertCount == expertLoadFactor.size(), "expert_count and expert_load_factor size mismatch");
    tr::doPlacement(metaInfo, expertLoadFactor.data(), cpuPlacement);
}
void initMoeBindings(pybind11::module_& m)
{
    // Bind MoeWeight struct
    py::class_<tr::MoeWeight>(m, "MoeWeight")
        .def(py::init<>())
        .def_property("weight_ptr", &tr::MoeWeight::getWeightPtr, &tr::MoeWeight::setWeightPtr)
        .def_readwrite("height", &tr::MoeWeight::mHeight)
        .def_readwrite("width", &tr::MoeWeight::mWidth)
        .def_readwrite("pitch", &tr::MoeWeight::mPitch)
        .def("__repr__",
            [](tr::MoeWeight const& self)
            {
                return "<MoeWeight ptr=" + std::to_string(self.getWeightPtr())
                    + " height=" + std::to_string(self.mHeight) + " width=" + std::to_string(self.mWidth)
                    + " pitch=" + std::to_string(self.mPitch) + ">";
            });

    // Bind MoeLoadBalanceMetaInfo struct
    py::class_<tk::MoeLoadBalanceMetaInfo>(m, "MoeLoadBalanceMetaInfo")
        .def(py::init<int, int, int, int, int>(), py::arg("expert_count"), py::arg("top_k"), py::arg("ep_rank"),
            py::arg("ep_size"), py::arg("slot_count_per_rank"))
        .def_readwrite("expert_count", &tk::MoeLoadBalanceMetaInfo::expertCount)
        .def_readwrite("top_k", &tk::MoeLoadBalanceMetaInfo::topK)
        .def_readwrite("ep_rank", &tk::MoeLoadBalanceMetaInfo::epRank)
        .def_readwrite("ep_size", &tk::MoeLoadBalanceMetaInfo::epSize)
        .def_readwrite("slot_count_per_rank", &tk::MoeLoadBalanceMetaInfo::slotCountPerRank);

    // Bind MoePlacementCpuInfo struct
    py::class_<tr::MoePlacementCpuInfo>(m, "MoePlacementCpuInfo")
        .def(py::init<>())
        .def_readwrite("expert_replica_count", &tr::MoePlacementCpuInfo::expertReplicaCount)
        .def_readwrite("rank_expert_ids", &tr::MoePlacementCpuInfo::rankExpertIds);

    // Bind SingleLayerMoeLoadBalancer class
    py::class_<tr::SingleLayerMoeLoadBalancer, std::shared_ptr<tr::SingleLayerMoeLoadBalancer>>(
        m, "SingleLayerMoeLoadBalancer")
        .def("add_single_weight_slot", &tr::SingleLayerMoeLoadBalancer::addSingleWeightSlot, py::arg("slot_id"),
            py::arg("name"), py::arg("weight_slot"), "Add a single weight slot for a specific slot ID",
            py::call_guard<py::gil_scoped_release>())
        .def("add_single_host_weight", &tr::SingleLayerMoeLoadBalancer::addSingleHostWeight, py::arg("expert_id"),
            py::arg("name"), py::arg("host_weight"), "Add a single host weight for a specific expert ID",
            py::call_guard<py::gil_scoped_release>())
        .def("set_initial_weight_assignments", &tr::SingleLayerMoeLoadBalancer::setInitialWeightAssignments,
            py::arg("initial_weight_assignments"), "Set initial weight assignments for each slot",
            py::call_guard<py::gil_scoped_release>())
        .def("get_pointer", &tr::SingleLayerMoeLoadBalancer::getSelfPtr,
            "Get the pointer of the SingleLayerMoeLoadBalancer", py::call_guard<py::gil_scoped_release>())
        .def("get_layer_id", &tr::SingleLayerMoeLoadBalancer::getLayerId,
            "Get the layer id of the SingleLayerMoeLoadBalancer", py::call_guard<py::gil_scoped_release>())
        .def("get_old_rank_expert_ids", &tr::SingleLayerMoeLoadBalancer::getOldRankExpertIds,
            "Get the old rank expert ids of the SingleLayerMoeLoadBalancer", py::call_guard<py::gil_scoped_release>());

    // Bind MoeLoadBalancer class
    py::class_<tr::MoeLoadBalancer>(m, "MoeLoadBalancer")
        .def(py::init<int, int, int>(), py::arg("ep_rank"), py::arg("ep_size"), py::arg("layer_updates_per_iter"),
            "Initialize the MoeLoadBalancer with the specified expert parallel rank, size, and update frequency",
            py::call_guard<py::gil_scoped_release>())
        .def("set_use_gpu_memcpy", &tr::MoeLoadBalancer::setUseGpuMemcpy, py::arg("use_gpu_memcpy"),
            "Set whether to use GPU memcpy for weight updates", py::call_guard<py::gil_scoped_release>())
        .def("add_layer", &tr::MoeLoadBalancer::AddLayer, py::arg("expert_count"), py::arg("top_k"),
            py::arg("slot_count_per_rank"), "Add a new MOE layer to the load balancer",
            py::call_guard<py::gil_scoped_release>())
        .def("finalize_model", &tr::MoeLoadBalancer::finalizeModel,
            "Finalize the model structure, must be called after all layers are added",
            py::call_guard<py::gil_scoped_release>())
        .def("set_warm_up_iter_count", &tr::MoeLoadBalancer::setWarmUpIterCount, py::arg("iter_count"),
            "Set the number of warm-up iterations", py::call_guard<py::gil_scoped_release>())
        .def("start_iter", &tr::MoeLoadBalancer::startIter, py::arg("iter_id"), py::arg("enable_statistic"),
            py::arg("enable_update_weights"), "Start a new iteration with the given ID and settings",
            py::call_guard<py::gil_scoped_release>())
        .def("end_iter", &tr::MoeLoadBalancer::endIter, py::arg("iter_id"), "End the iteration with the given ID",
            py::call_guard<py::gil_scoped_release>())
        .def("shutdown", &tr::MoeLoadBalancer::shutdown, "Shutdown the load balancer and clean up resources",
            py::call_guard<py::gil_scoped_release>());

    m.def("is_host_accessible_device_memory_supported", &tr::HostAccessibleDeviceAllocator::isSupported,
        "Whether the current system supports host-accessible device memory");

    // Bind do_replication function for testing
    m.def("do_replication", &pyDoReplication, py::arg("meta_info"), py::arg("expert_load_factor"),
        py::arg("cpu_placement"), "Do replication");

    // Bind do_placement function for testing
    m.def("do_placement", &pyDoPlacement, py::arg("meta_info"), py::arg("expert_load_factor"), py::arg("cpu_placement"),
        "Do placement");
}

} // namespace tensorrt_llm::pybind::runtime
@ -1,27 +0,0 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <pybind11/pybind11.h>

namespace tensorrt_llm::pybind::runtime
{

void initMoeBindings(pybind11::module_& m);

} // namespace tensorrt_llm::pybind::runtime
@ -1,85 +0,0 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "modelSpecBinding.h"
#include "tensorrt_llm/testing/modelSpec.h"

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

namespace py = pybind11;
using tensorrt_llm::testing::ModelSpec;
using tensorrt_llm::testing::KVCacheType;
using tensorrt_llm::testing::QuantMethod;
using tensorrt_llm::testing::OutputContentType;

namespace tensorrt_llm::pybind::testing
{

void initBindings(py::module_& m)
{
    py::enum_<QuantMethod>(m, "QuantMethod", py::arithmetic(), "Quantization Method")
        .value("NONE", QuantMethod::kNONE, "No Quantization")
        .value("SMOOTH_QUANT", QuantMethod::kSMOOTH_QUANT, "Smooth Quantization");

    py::enum_<OutputContentType>(m, "OutputContentType", py::arithmetic(), "Output Content Type")
        .value("NONE", OutputContentType::kNONE, "No Output Content")
        .value("CONTEXT_LOGITS", OutputContentType::kCONTEXT_LOGITS, "Context Logits")
        .value("GENERATION_LOGITS", OutputContentType::kGENERATION_LOGITS, "Generation Logits")
        .value("LOG_PROBS", OutputContentType::kLOG_PROBS, "Log Probs")
        .value("CUM_LOG_PROBS", OutputContentType::kCUM_LOG_PROBS, "Cumulative Log Probs");

    py::class_<ModelSpec>(m, "ModelSpec")
        .def(py::init<std::string const&, nvinfer1::DataType>())
        .def("use_gpt_plugin", &ModelSpec::useGptAttentionPlugin)
        .def("use_packed_input", &ModelSpec::usePackedInput)
        .def("set_kv_cache_type", &ModelSpec::setKVCacheType)
        .def("use_decoder_per_request", &ModelSpec::useDecoderPerRequest)
        .def("use_tensor_parallelism", &ModelSpec::useTensorParallelism)
        .def("use_pipeline_parallelism", &ModelSpec::usePipelineParallelism)
        .def("use_context_parallelism", &ModelSpec::useContextParallelism)
        .def("set_draft_tokens", &ModelSpec::setDraftTokens)
        .def("use_accept_by_logits", &ModelSpec::useAcceptByLogits)
        .def("use_mamba_plugin", &ModelSpec::useMambaPlugin)
        .def("gather_logits", &ModelSpec::gatherLogits)
        .def("replace_logits", &ModelSpec::replaceLogits)
        .def("return_log_probs", &ModelSpec::returnLogProbs)
        .def("smoke_test", &ModelSpec::smokeTest)
        .def("use_medusa", &ModelSpec::useMedusa)
        .def("use_eagle", &ModelSpec::useEagle)
        .def("use_lookahead_decoding", &ModelSpec::useLookaheadDecoding)
        .def("use_explicit_draft_tokens_decoding", &ModelSpec::useExplicitDraftTokensDecoding)
        .def("use_draft_tokens_external_decoding", &ModelSpec::useDraftTokensExternalDecoding)
        .def("use_logits", &ModelSpec::useLogits)
        .def("use_multiple_profiles", &ModelSpec::useMultipleProfiles)
        .def("set_max_input_length", &ModelSpec::setMaxInputLength)
        .def("set_max_output_length", &ModelSpec::setMaxOutputLength)
        .def("set_quant_method", &ModelSpec::setQuantMethod)
        .def("use_lora_plugin", &ModelSpec::useLoraPlugin)
        .def("get_input_file", &ModelSpec::getInputFile)
        .def("get_model_path", &ModelSpec::getModelPath)
        .def("get_results_file", &ModelSpec::getResultsFile)
        .def("get_generation_logits_file", &ModelSpec::getGenerationLogitsFile)
        .def("get_context_logits_file", &ModelSpec::getContextLogitsFile)
        .def("get_cum_log_probs_file", &ModelSpec::getCumLogProbsFile)
        .def("get_log_probs_file", &ModelSpec::getLogProbsFile)
        .def("enable_context_fmha_fp32_acc", &ModelSpec::enableContextFMHAFp32Acc)
        .def("get_enable_context_fmha_fp32_acc", &ModelSpec::getEnableContextFMHAFp32Acc)
        .def("__copy__", [](ModelSpec const& self) { return ModelSpec(self); });
}

} // namespace tensorrt_llm::pybind::testing
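The file above follows a common pybind11 pattern: arithmetic-capable enums with per-value docstrings, plus a class whose C++ fluent setters are exposed as snake_case methods. A minimal self-contained sketch of the same pattern; the types and names here are invented for illustration, not taken from the repository:

#include <string>

#include <pybind11/pybind11.h>

namespace py = pybind11;

enum class Method { kNONE = 0, kSMOOTH_QUANT = 1 };

struct Spec
{
    std::string name;
    int maxInputLength = 0;

    Spec& setMaxInputLength(int len)
    {
        maxInputLength = len;
        return *this; // fluent setter, mirroring ModelSpec's set_* methods
    }
};

PYBIND11_MODULE(example, m)
{
    // py::arithmetic() allows arithmetic and bitwise ops on the enum in Python.
    py::enum_<Method>(m, "Method", py::arithmetic(), "Quantization Method")
        .value("NONE", Method::kNONE, "No Quantization")
        .value("SMOOTH_QUANT", Method::kSMOOTH_QUANT, "Smooth Quantization");

    py::class_<Spec>(m, "Spec")
        .def(py::init<>())
        .def("set_max_input_length", &Spec::setMaxInputLength)
        // Explicit __copy__ support, as the ModelSpec binding above provides.
        .def("__copy__", [](Spec const& self) { return Spec(self); });
}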
@ -1,30 +0,0 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "tensorrt_llm/pybind/common/customCasters.h"
#include <pybind11/pybind11.h>

namespace py = pybind11;

namespace tensorrt_llm::pybind::testing
{

void initBindings(py::module_& m);

} // namespace tensorrt_llm::pybind::testing
@ -1,83 +0,0 @@
/*
 * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "bindings.h"
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <tensorrt_llm/kernels/helixAllToAll.h>
#include <tensorrt_llm/thop/attentionOp.h>
#include <tensorrt_llm/thop/moeAlltoAllMeta.h>
#include <torch/extension.h>

namespace py = pybind11;

namespace tensorrt_llm::pybind::thop
{

void initBindings(pybind11::module_& m)
{
    // Export MoE A2A constants
    for (auto const& kv : torch_ext::moe_comm::getMoeA2AMetaInfoIndexPairs())
    {
        m.attr(kv.first) = py::int_(kv.second);
    }

    m.def("attention", &torch_ext::attention,
        // Parameters with default values using std::nullopt for optional arguments
        py::arg("q"), py::arg("k") = std::nullopt, py::arg("v") = std::nullopt, py::arg("output"),
        py::arg("output_sf") = std::nullopt, py::arg("workspace_") = std::nullopt, py::arg("sequence_length"),
        py::arg("host_past_key_value_lengths"), py::arg("host_total_kv_lens"), py::arg("context_lengths"),
        py::arg("host_context_lengths"), py::arg("host_request_types"),
        py::arg("kv_cache_block_offsets") = std::nullopt, py::arg("host_kv_cache_pool_pointers") = std::nullopt,
        py::arg("host_kv_cache_pool_mapping") = std::nullopt, py::arg("cache_indirection") = std::nullopt,
        py::arg("kv_scale_orig_quant") = std::nullopt, py::arg("kv_scale_quant_orig") = std::nullopt,
        py::arg("out_scale") = std::nullopt, py::arg("rotary_inv_freq") = std::nullopt,
        py::arg("rotary_cos_sin") = std::nullopt, py::arg("latent_cache") = std::nullopt,
        py::arg("q_pe") = std::nullopt, py::arg("block_ids_per_seq") = std::nullopt,
        py::arg("attention_sinks") = std::nullopt, py::arg("is_fused_qkv"), py::arg("update_kv_cache"),
        py::arg("predicted_tokens_per_seq"), py::arg("layer_idx"), py::arg("num_heads"), py::arg("num_kv_heads"),
        py::arg("head_size"), py::arg("tokens_per_block") = std::nullopt, py::arg("max_num_requests"),
        py::arg("max_context_length"), py::arg("attention_window_size"), py::arg("sink_token_length"),
        py::arg("beam_width"), py::arg("mask_type"), py::arg("quant_mode"), py::arg("q_scaling"),
        py::arg("position_embedding_type"), py::arg("rotary_embedding_dim"), py::arg("rotary_embedding_base"),
        py::arg("rotary_embedding_scale_type"), py::arg("rotary_embedding_scales"),
        py::arg("rotary_embedding_max_position_info"), py::arg("use_paged_context_fmha"),
        py::arg("attention_input_type") = std::nullopt, py::arg("is_mla_enable"),
        py::arg("chunked_prefill_buffer_batch_size") = std::nullopt, py::arg("q_lora_rank") = std::nullopt,
        py::arg("kv_lora_rank") = std::nullopt, py::arg("qk_nope_head_dim") = std::nullopt,
        py::arg("qk_rope_head_dim") = std::nullopt, py::arg("v_head_dim") = std::nullopt,
        py::arg("mrope_rotary_cos_sin") = std::nullopt, py::arg("mrope_position_deltas") = std::nullopt,
        py::arg("mla_tensor_params"), py::arg("attention_chunk_size") = std::nullopt,
        py::arg("softmax_stats_tensor") = std::nullopt, py::arg("spec_decoding_bool_params"),
        py::arg("spec_decoding_tensor_params"), py::arg("sparse_kv_indices") = std::nullopt,
        py::arg("sparse_kv_offsets") = std::nullopt, py::arg("sparse_attn_indices") = std::nullopt,
        py::arg("sparse_attn_offsets") = std::nullopt, py::arg("sparse_attn_indices_block_size"),
        py::arg("sparse_mla_topk") = std::nullopt,
        py::arg("skip_softmax_threshold_scale_factor_prefill") = std::nullopt,
        py::arg("skip_softmax_threshold_scale_factor_decode") = std::nullopt,
        py::arg("skip_softmax_stat") = std::nullopt, py::arg("cu_q_seqlens") = std::nullopt,
        py::arg("cu_kv_seqlens") = std::nullopt, py::arg("fmha_scheduler_counter") = std::nullopt,
        py::arg("mla_bmm1_scale") = std::nullopt, py::arg("mla_bmm2_scale") = std::nullopt,
        py::arg("quant_q_buffer") = std::nullopt, "Multi-head attention operation",
        py::call_guard<py::gil_scoped_release>());

    m.def(
        "get_helix_workspace_size_per_rank",
        [](int cp_size) { return tensorrt_llm::kernels::computeHelixWorkspaceSizePerRank(cp_size); },
        py::arg("cp_size"), "Get helix all-to-all workspace size per rank in bytes");
}
} // namespace tensorrt_llm::pybind::thop
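Two mechanics in the attention binding above are worth noting: std::optional parameters default to std::nullopt via py::arg(...) = std::nullopt, and py::call_guard<py::gil_scoped_release>() drops the GIL for the duration of the C++ call. A reduced, self-contained sketch of both; the function here is a stand-in, not the real attention op:

#include <optional>

#include <pybind11/pybind11.h>
#include <pybind11/stl.h> // provides the std::optional <-> Python converter

namespace py = pybind11;

// Stand-in for a long-running op: 'scale' may be omitted on the Python side.
double run(double x, std::optional<double> scale)
{
    return x * scale.value_or(1.0);
}

PYBIND11_MODULE(example, m)
{
    m.def("run", &run, py::arg("x"), py::arg("scale") = std::nullopt,
        "Stand-in op with an optional argument",
        // Release the GIL while the C++ body executes, as the attention
        // binding above does for the real kernel.
        py::call_guard<py::gil_scoped_release>());
}

From Python, example.run(2.0) and example.run(2.0, scale=3.0) are then both valid calls.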
@ -1,27 +0,0 @@
/*
 * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "tensorrt_llm/pybind/common/customCasters.h"
#include <pybind11/pybind11.h>

namespace tensorrt_llm::pybind::thop
{

void initBindings(pybind11::module_& m);

} // namespace tensorrt_llm::pybind::thop
@ -1,55 +0,0 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "bindings.h"
#include "tensorrt_llm/common/config.h"
#include "tensorrt_llm/kernels/userbuffers/ub_interface.h"
#include "tensorrt_llm/kernels/userbuffers/userbuffersManager.h"

namespace py = pybind11;
namespace tub = tensorrt_llm::runtime::ub;

TRTLLM_NAMESPACE_BEGIN

namespace kernels::userbuffers
{

void UserBufferBindings::initBindings(pybind11::module_& m)
{
    py::class_<tub::UBBuffer>(m, "UBBuffer")
        .def_readonly("size", &tub::UBBuffer::size)
        .def_property_readonly("addr", [](tub::UBBuffer& self) { return reinterpret_cast<intptr_t>(self.addr); })
        .def_readonly("handle", &tub::UBBuffer::handle)
        .def("invalid", &tub::UBBuffer::invalid, py::call_guard<py::gil_scoped_release>());

    m.def(
        "ub_initialize", [](int tp_size) { tub::ub_initialize(tp_size); }, py::call_guard<py::gil_scoped_release>());
    m.def("ub_is_initialized", &tub::ub_is_initialized, py::call_guard<py::gil_scoped_release>());
    m.def(
        "ub_allocate", [](size_t bytes) { return tub::ub_allocate(bytes); }, py::call_guard<py::gil_scoped_release>());
    m.def(
        "ub_deallocate", [](intptr_t addr) { return tub::ub_deallocate(reinterpret_cast<void*>(addr)); },
        py::call_guard<py::gil_scoped_release>());
    m.def("ub_get", &tub::ub_get, py::call_guard<py::gil_scoped_release>());
    m.def("ub_supported", &tub::ub_supported, py::call_guard<py::gil_scoped_release>());

    m.def("initialize_userbuffers_manager", &tub::initialize_userbuffers_manager,
        py::call_guard<py::gil_scoped_release>());
}
} // namespace kernels::userbuffers

TRTLLM_NAMESPACE_END
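The UBBuffer binding above exposes a raw pointer to Python as a plain integer (reinterpret_cast to intptr_t), and ub_deallocate reverses the cast on the way back in, since Python has no native notion of a void*. A minimal sketch of that round-trip with an invented buffer type; nothing here is the repository's actual allocator:

#include <cstdint>

#include <pybind11/pybind11.h>

namespace py = pybind11;

struct Buffer
{
    void* addr = nullptr;
    std::size_t size = 0;
};

PYBIND11_MODULE(example, m)
{
    py::class_<Buffer>(m, "Buffer")
        .def(py::init<>())
        .def_readonly("size", &Buffer::size)
        // Python cannot hold a void*; hand the address over as an integer.
        .def_property_readonly(
            "addr", [](Buffer& self) { return reinterpret_cast<std::intptr_t>(self.addr); });

    m.def("release",
        [](std::intptr_t addr)
        {
            // Round-trip the integer back to a pointer, as ub_deallocate does.
            void* p = reinterpret_cast<void*>(addr);
            (void) p; // a real module would hand p back to its allocator here
        });
}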
@ -1,35 +0,0 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "tensorrt_llm/common/config.h"
#include "tensorrt_llm/pybind/common/customCasters.h"
#include <pybind11/pybind11.h>

TRTLLM_NAMESPACE_BEGIN

namespace kernels::userbuffers
{
class UserBufferBindings
{
public:
    static void initBindings(pybind11::module_& m);
};
} // namespace kernels::userbuffers

TRTLLM_NAMESPACE_END