[None][chore] Removing cpp/tensorrt_llm/pybind (#11026)

Signed-off-by: Linda-Stadter <57756729+Linda-Stadter@users.noreply.github.com>
Linda 2026-01-28 11:25:11 +01:00 committed by GitHub
parent 38bcee189c
commit 29647d9446
44 changed files with 0 additions and 6711 deletions

cpp/tensorrt_llm/pybind/CMakeLists.txt

@@ -1,77 +0,0 @@
set(TRTLLM_PYBIND_MODULE bindings)
set(TRTLLM_PYBIND_MODULE
${TRTLLM_PYBIND_MODULE}
PARENT_SCOPE)
set(SRCS
batch_manager/algorithms.cpp
batch_manager/bindings.cpp
batch_manager/buffers.cpp
batch_manager/cacheTransceiver.cpp
batch_manager/kvCacheConnector.cpp
batch_manager/kvCacheManager.cpp
batch_manager/kvCacheManagerV2Utils.cpp
batch_manager/llmRequest.cpp
executor/bindings.cpp
executor/executor.cpp
executor/executorConfig.cpp
executor/request.cpp
process_group/bindings.cpp
runtime/bindings.cpp
runtime/hostfunc.cpp
common/tllmExceptions.cpp
testing/modelSpecBinding.cpp
runtime/moeBindings.cpp
userbuffers/bindings.cpp
../runtime/ipcNvlsMemory.cu
thop/bindings.cpp
bindings.cpp)
include_directories(${PROJECT_SOURCE_DIR}/include)
pybind11_add_module(${TRTLLM_PYBIND_MODULE} ${SRCS})
set_property(TARGET ${TRTLLM_PYBIND_MODULE} PROPERTY POSITION_INDEPENDENT_CODE
ON)
target_link_directories(${TRTLLM_PYBIND_MODULE} PUBLIC
"${TORCH_INSTALL_PREFIX}/lib")
if(ENABLE_NVSHMEM)
target_link_libraries(${TRTLLM_PYBIND_MODULE} PUBLIC nvshmem::nvshmem_host
nvshmem::nvshmem_device)
endif()
target_link_libraries(
${TRTLLM_PYBIND_MODULE}
PUBLIC ${SHARED_TARGET}
${Python3_LIBRARIES}
${TORCH_LIBRARIES}
torch_python
${CUDA_DRV_LIB}
${CUDA_NVML_LIB}
th_common
pg_utils)
target_compile_definitions(
${TRTLLM_PYBIND_MODULE} PUBLIC TRTLLM_PYBIND_MODULE=${TRTLLM_PYBIND_MODULE}
PYBIND11_DETAILED_ERROR_MESSAGES=1)
if(NOT WIN32)
set(TRTLLM_PYBIND_MODULE_RPATH_LIST
"$ORIGIN/libs" # TRTLLM libraries
"$ORIGIN/../../.." # Shared libraries under $PREFIX/lib
"$ORIGIN/../nvidia/nccl/lib" # NCCL libraries
)
set_target_properties(
${TRTLLM_PYBIND_MODULE} PROPERTIES BUILD_RPATH
"${TRTLLM_PYBIND_MODULE_RPATH_LIST}")
set_target_properties(
${TRTLLM_PYBIND_MODULE} PROPERTIES LINK_FLAGS
"${AS_NEEDED_FLAG} ${UNDEFINED_FLAG}")
endif()
# Build transfer_agent_binding when building bindings (if NIXL is enabled)
if(TARGET ${TRANSFER_AGENT_BINDING_TARGET})
add_dependencies(${TRTLLM_PYBIND_MODULE} ${TRANSFER_AGENT_BINDING_TARGET})
endif()
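
The CMake file above builds a single pybind11 extension module named "bindings", linked against the shared TRT-LLM target, Torch, and the CUDA driver/NVML libraries. A minimal sketch of how such a module is typically consumed from Python follows; the tensorrt_llm.bindings package path is inferred from the exception names registered in batch_manager/bindings.cpp further down and is otherwise an assumption, not something this commit documents.

# Hedged sketch (assumed packaging): the compiled "bindings" extension ships
# inside the tensorrt_llm Python package and is imported like any other module.
import tensorrt_llm.bindings as trtllm_bindings

# Submodules such as the internal batch_manager bindings hang off this module;
# the exact attribute layout is an assumption based on the names in this diff.
print([name for name in dir(trtllm_bindings) if not name.startswith("_")])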

cpp/tensorrt_llm/pybind/batch_manager/algorithms.cpp

@@ -1,127 +0,0 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "algorithms.h"
#include "tensorrt_llm/batch_manager/allocateKvCache.h"
#include "tensorrt_llm/batch_manager/assignReqSeqSlots.h"
#include "tensorrt_llm/batch_manager/capacityScheduler.h"
#include "tensorrt_llm/batch_manager/createNewDecoderRequests.h"
#include "tensorrt_llm/batch_manager/decoderBuffers.h"
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
#include "tensorrt_llm/batch_manager/llmRequest.h"
#include "tensorrt_llm/batch_manager/logitsPostProcessor.h"
#include "tensorrt_llm/batch_manager/medusaBuffers.h"
#include "tensorrt_llm/batch_manager/microBatchScheduler.h"
#include "tensorrt_llm/batch_manager/pauseRequests.h"
#include "tensorrt_llm/batch_manager/peftCacheManager.h"
#include "tensorrt_llm/runtime/decoderState.h"
#include "tensorrt_llm/runtime/torch.h"
#include "tensorrt_llm/runtime/torchView.h"
#include <ATen/core/TensorBody.h>
#include <pybind11/cast.h>
#include <pybind11/functional.h>
#include <pybind11/operators.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <torch/extension.h>
#include <optional>
namespace py = pybind11;
namespace tr = tensorrt_llm::runtime;
using namespace tensorrt_llm::batch_manager;
void tensorrt_llm::pybind::batch_manager::algorithms::initBindings(pybind11::module_& m)
{
py::class_<CapacityScheduler>(m, CapacityScheduler::name)
.def(py::init<SizeType32, executor::CapacitySchedulerPolicy, bool, bool, LlmRequestState, LlmRequestState>(),
py::arg("max_num_requests"), py::arg("capacity_scheduler_policy"), py::arg("has_kv_cache_manager"),
py::arg("two_step_lookahead") = false,
py::arg_v("no_schedule_until_state", LlmRequestState::kCONTEXT_INIT, "LlmRequestState.CONTEXT_INIT"),
py::arg_v("no_schedule_after_state", LlmRequestState::kGENERATION_COMPLETE,
"LlmRequestState.GENERATION_COMPLETE"))
.def("__call__", &CapacityScheduler::operator(), py::arg("active_requests"),
py::arg("kv_cache_manager") = nullptr, py::arg("peft_cache_manager") = nullptr,
py::arg("cross_kv_cache_manager") = nullptr)
.def("name", [](CapacityScheduler const&) { return CapacityScheduler::name; });
py::class_<MicroBatchScheduler>(m, MicroBatchScheduler::name)
.def(py::init<std::optional<batch_scheduler::ContextChunkingConfig>, std::optional<SizeType32>, LlmRequestState,
LlmRequestState>(),
py::arg("ctx_chunk_config") = std::nullopt, py::arg("max_context_length") = std::nullopt,
py::arg_v("no_schedule_until_state", LlmRequestState::kCONTEXT_INIT, "LlmRequestState.CONTEXT_INIT"),
py::arg_v("no_schedule_after_state", LlmRequestState::kGENERATION_TO_COMPLETE,
"LlmRequestState.GENERATION_TO_COMPLETE"))
.def("__call__", &MicroBatchScheduler::operator(), py::arg("active_requests"), py::arg("inflight_req_ids"),
py::arg("max_batch_size_runtime"), py::arg("max_num_tokens_runtime"))
.def("name", [](MicroBatchScheduler const&) { return MicroBatchScheduler::name; });
py::class_<PauseRequests>(m, PauseRequests::name)
.def(py::init<SizeType32>(), py::arg("max_input_len"))
.def("__call__", &PauseRequests::operator(), py::arg("requests_to_pause"), py::arg("inflight_req_ids"),
py::arg("req_ids_to_pause"), py::arg("pause_flagged"), py::arg("seq_slot_manager"),
py::arg("kv_cache_manager") = std::nullopt, py::arg("cross_kv_cache_manager") = std::nullopt,
py::arg("peft_cache_manager") = std::nullopt)
.def("name", [](PauseRequests const&) { return PauseRequests::name; });
py::class_<AssignReqSeqSlots>(m, AssignReqSeqSlots::name)
.def(py::init())
.def("__call__", &AssignReqSeqSlots::operator(), py::arg("seq_slot_manager"), py::arg("context_requests"),
py::arg("generation_requests"))
.def("name", [](AssignReqSeqSlots const&) { return AssignReqSeqSlots::name; });
py::class_<AllocateKvCache>(m, AllocateKvCache::name)
.def(py::init(), py::call_guard<py::gil_scoped_release>())
.def("__call__", &AllocateKvCache::operator(), py::arg("kv_cache_manager"), py::arg("context_requests"),
py::arg("generation_requests"), py::arg("model_config"), py::arg("cross_kv_cache_manager") = std::nullopt,
py::call_guard<py::gil_scoped_release>())
.def("name", [](AllocateKvCache const&) { return AllocateKvCache::name; });
py::class_<LogitsPostProcessor>(m, LogitsPostProcessor::name)
.def(py::init())
.def("__call__", &LogitsPostProcessor::operator(), py::arg("decoder_input_buffers"),
py::arg("replicate_logits_post_processor"), py::arg("world_config"), py::arg("stream"),
py::arg("logits_post_processor_batched") = std::nullopt)
.def("name", [](LogitsPostProcessor const&) { return LogitsPostProcessor::name; });
py::class_<CreateNewDecoderRequests>(m, CreateNewDecoderRequests::name)
.def(py::init<bool, bool, bool>(), py::arg("speculative_decoding_fast_logits"),
py::arg("is_leader_in_orch_mode"), py::arg("is_normalize_log_probs"))
.def(
"__call__",
[](CreateNewDecoderRequests& self, tr::ModelConfig const& modelConfig, tr::WorldConfig const& worldConfig,
executor::DecodingConfig const& decodingConfig, RequestVector const& contextRequests,
nvinfer1::DataType logitsType, DecoderInputBuffers& inputBuffers,
runtime::decoder::DecoderState& decoderState, tensorrt_llm::runtime::CudaStream const& runtimeStream,
tensorrt_llm::runtime::CudaStream const& decoderStream, SizeType32 maxSequenceLength,
SizeType32 beamWidth)
{
OptionalRef<MedusaBuffers const> medusaBuffers = std::nullopt;
auto [batchSlots, samplingConfigs, lookaheadPrompt, lookaheadAlgoConfigs]
= self(modelConfig, worldConfig, decodingConfig, contextRequests, logitsType, inputBuffers,
decoderState, runtimeStream, decoderStream, maxSequenceLength, beamWidth, medusaBuffers);
return std::tuple{runtime::Torch::tensor(batchSlots), std::move(samplingConfigs),
std::move(lookaheadPrompt), std::move(lookaheadAlgoConfigs)};
},
py::arg("model_config"), py::arg("world_config"), py::arg("decoding_config"), py::arg("context_requests"),
py::arg("logits_type"), py::arg("decoder_input_buffers"), py::arg("decoder_state"),
py::arg("runtime_stream"), py::arg("decoder_stream"), py::arg("max_sequence_length"), py::arg("beam_width"))
.def("name", [](CreateNewDecoderRequests const&) { return CreateNewDecoderRequests::name; });
}
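
The file above only registers the algorithm bindings; a hedged usage sketch follows. The keyword arguments mirror the py::arg names registered in algorithms.cpp, while the Python import paths and the scheduler-policy enum value are assumptions, not taken from this file.

# Hedged sketch: constructing the scheduler bindings defined above. Import
# paths and the policy value are assumptions; kwargs follow the py::arg names.
from tensorrt_llm.bindings.internal import batch_manager as bm  # assumed path
from tensorrt_llm.bindings import executor as trtllm_executor   # assumed path

capacity_scheduler = bm.CapacityScheduler(
    max_num_requests=8,
    capacity_scheduler_policy=trtllm_executor.CapacitySchedulerPolicy.GUARANTEED_NO_EVICT,
    has_kv_cache_manager=True,
)  # two_step_lookahead and the no_schedule_*_state arguments keep their bound defaults

micro_batch_scheduler = bm.MicroBatchScheduler()  # default chunking config and context length

# Both objects are callable (bound via __call__ above) and are invoked once per
# scheduling iteration, e.g. capacity_scheduler(active_requests, kv_cache_manager=...).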

cpp/tensorrt_llm/pybind/batch_manager/algorithms.h

@@ -1,28 +0,0 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/pybind/common/customCasters.h"
#include <pybind11/pybind11.h>
namespace tensorrt_llm::pybind::batch_manager::algorithms
{
void initBindings(pybind11::module_& m);
}

cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp

@@ -1,517 +0,0 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "bindings.h"
#include "tensorrt_llm/batch_manager/common.h"
#include "tensorrt_llm/batch_manager/decoderBuffers.h"
#include "tensorrt_llm/batch_manager/kvCacheConnector.h"
#include "tensorrt_llm/batch_manager/medusaBuffers.h"
#include "tensorrt_llm/batch_manager/microBatchScheduler.h"
#include "tensorrt_llm/batch_manager/peftCacheManager.h"
#include "tensorrt_llm/batch_manager/rnnStateManager.h"
#include "tensorrt_llm/batch_manager/sequenceSlotManager.h"
#include "tensorrt_llm/pybind/common/bindTypes.h"
#include "tensorrt_llm/runtime/gptDecoderBatched.h"
#include "tensorrt_llm/runtime/runtimeKernels.h"
#include "tensorrt_llm/runtime/torch.h"
#include "tensorrt_llm/runtime/torchView.h"
#include <ATen/ATen.h>
#include <pybind11/chrono.h>
#include <pybind11/functional.h>
#include <pybind11/operators.h>
#include <pybind11/stl.h>
#include <pybind11/stl_bind.h>
#include <torch/extension.h>
#include <tuple>
namespace py = pybind11;
namespace tb = tensorrt_llm::batch_manager;
namespace tle = tensorrt_llm::executor;
namespace tr = tensorrt_llm::runtime;
using namespace tensorrt_llm::runtime;
namespace tensorrt_llm::pybind::batch_manager
{
void initBindings(pybind11::module_& m)
{
using GenLlmReq = tb::GenericLlmRequest<runtime::ITensor::SharedPtr>;
// Create and register exceptions in module scope
static PyObject* peft_exc = PyErr_NewException(
"tensorrt_llm.bindings.internal.batch_manager.PeftTaskNotCachedException", nullptr, nullptr);
static PyObject* lora_exc
= PyErr_NewException("tensorrt_llm.bindings.internal.batch_manager.LoraCacheFullException", nullptr, nullptr);
m.add_object("PeftTaskNotCachedException", py::handle(peft_exc));
m.add_object("LoraCacheFullException", py::handle(lora_exc));
// Register with no captures
py::register_exception_translator(
[](std::exception_ptr p)
{
try
{
if (p)
std::rethrow_exception(p);
}
catch (const tb::PeftTaskNotCachedException& e)
{
PyErr_SetString(peft_exc, e.what());
}
catch (const tr::LoraCacheFullException& e)
{
PyErr_SetString(lora_exc, e.what());
}
});
PybindUtils::bindSet<tb::ReqIdsSet>(m, "ReqIdsSet");
py::enum_<tb::LlmRequestType>(m, "LlmRequestType")
.value("LLMREQUEST_TYPE_CONTEXT_AND_GENERATION", tb::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION)
.value("LLMREQUEST_TYPE_CONTEXT_ONLY", tb::LLMREQUEST_TYPE_CONTEXT_ONLY)
.value("LLMREQUEST_TYPE_GENERATION_ONLY", tb::LLMREQUEST_TYPE_GENERATION_ONLY)
.export_values();
py::class_<tb::batch_scheduler::ContextChunkingConfig>(m, "ContextChunkingConfig")
.def(py::init<tle::ContextChunkingPolicy, tensorrt_llm::runtime::SizeType32>(), py::arg("chunking_policy"),
py::arg("chunk_unit_size"))
.def_readwrite("chunking_policy", &tb::batch_scheduler::ContextChunkingConfig::chunkingPolicy)
.def_readwrite("chunk_unit_size", &tb::batch_scheduler::ContextChunkingConfig::chunkUnitSize);
py::classh<GenLlmReq>(m, "GenericLlmRequest")
.def("set_exclude_input_from_output", &GenLlmReq::setExcludeInputFromOutput, py::arg("exclude"))
.def("get_num_tokens", &GenLlmReq::getNumTokens, py::arg("beam"))
.def_property_readonly("max_beam_num_tokens", &GenLlmReq::getMaxBeamNumTokens)
.def("get_token", &GenLlmReq::getToken, py::arg("beam"), py::arg("pos"))
.def("get_tokens", py::overload_cast<GenLlmReq::SizeType32>(&GenLlmReq::getTokens, py::const_), py::arg("beam"))
.def("get_tokens", py::overload_cast<>(&GenLlmReq::getTokens, py::const_))
.def("get_last_tokens", py::overload_cast<GenLlmReq::SizeType32>(&GenLlmReq::getLastTokens), py::arg("beam"))
.def("get_last_tokens", py::overload_cast<>(&GenLlmReq::getLastTokens))
.def("get_beam_width_by_iter", &GenLlmReq::getBeamWidthByIter, py::arg("for_next_iteration") = false)
.def_property_readonly("max_num_generated_tokens", &GenLlmReq::getMaxNumGeneratedTokens)
.def("will_complete_next_iteration", &GenLlmReq::willCompleteNextIteration)
.def("add_new_token", &GenLlmReq::addNewToken, py::arg("token"), py::arg("beam"))
.def("add_new_tokens", &GenLlmReq::addNewTokens, py::arg("beam_tokens"))
.def_property_readonly("num_draft_tokens", &GenLlmReq::getNumDraftTokens)
.def("set_generated_tokens", &GenLlmReq::setGeneratedTokens, py::arg("generated_beam_tokens"))
.def("pause", &GenLlmReq::pause, py::arg("max_input_len"))
.def_property("max_sent_token_len", &GenLlmReq::getMaxSentTokenLen, &GenLlmReq::setMaxSentTokenLen)
.def_property_readonly("prompt_embedding_table", &GenLlmReq::getPromptEmbeddingTable)
.def_property_readonly("multimodal_embedding", &GenLlmReq::getMultimodalEmbedding)
.def_property_readonly("mrope_rotary_cos_sin", &GenLlmReq::getMropeRotaryCosSin)
.def_property_readonly("bad_words_list", &GenLlmReq::getBadWordsList)
.def_property("draft_logits", &GenLlmReq::getDraftLogits, &GenLlmReq::setDraftLogits)
.def_property_readonly("embedding_bias", &GenLlmReq::getEmbeddingBias)
.def_property("lora_config", &GenLlmReq::getLoraConfig, &GenLlmReq::setLoraConfig)
.def_property("lora_weights", &GenLlmReq::getLoraWeights, &GenLlmReq::setLoraWeights)
.def_property_readonly("stop_words_list", &GenLlmReq::getStopWordsList)
.def_property_readonly("context_logits", &GenLlmReq::getContextLogitsHost)
.def_property_readonly("generation_logits", &GenLlmReq::getGenerationLogitsHost)
.def_property_readonly("prompt_vocab_size", &GenLlmReq::getPromptVocabSize)
.def_property_readonly("mrope_position_deltas", &GenLlmReq::getMropePositionDeltas)
.def_property_readonly("lora_task_id", &GenLlmReq::getLoraTaskId)
.def_property_readonly("lookahead_config", &GenLlmReq::getLookaheadConfig)
.def_property("context_chunk_size", &GenLlmReq::getContextChunkSize, &GenLlmReq::setContextChunkSize)
.def_property("decoding_iter", &GenLlmReq::getDecodingIter, &GenLlmReq::setDecodingIter)
.def_readwrite("request_id", &GenLlmReq::mRequestId)
.def_readwrite("prompt_len", &GenLlmReq::mPromptLen)
.def_readwrite("max_new_tokens", &GenLlmReq::mMaxNewTokens)
.def_readwrite("sampling_config", &GenLlmReq::mSamplingConfig)
.def_property("state", &GenLlmReq::getState, &GenLlmReq::setState)
.def_property_readonly("state_value", [](GenLlmReq const& self) { return static_cast<int>(self.getState()); })
.def_property("streaming", &GenLlmReq::isStreaming, &GenLlmReq::setStreaming)
.def_readwrite("end_id", &GenLlmReq::mEndId)
.def_readwrite("pad_id", &GenLlmReq::mPadId)
.def_readwrite("seq_slot", &GenLlmReq::mSeqSlot)
.def_property_readonly("return_log_probs", &GenLlmReq::returnLogProbs)
.def_property_readonly("return_context_logits", &GenLlmReq::getReturnContextLogits)
.def_property_readonly("return_generation_logits", &GenLlmReq::getReturnGenerationLogits)
.def_property_readonly("log_probs", py::overload_cast<>(&GenLlmReq::getLogProbs, py::const_))
.def("get_log_probs", py::overload_cast<GenLlmReq::SizeType32>(&GenLlmReq::getLogProbs, py::const_))
.def("set_log_probs", &GenLlmReq::setLogProbs, py::arg("log_probs"), py::arg("beam"))
.def("set_return_encoder_output", &GenLlmReq::setReturnEncoderOutput, py::arg("return_encoder_output"))
.def("get_return_encoder_output", &GenLlmReq::getReturnEncoderOutput)
.def("priority", py::overload_cast<>(&GenLlmReq::priority, py::const_))
.def("set_priority", py::overload_cast<tle::PriorityType>(&GenLlmReq::setPriority))
.def_property_readonly("cum_log_probs", &GenLlmReq::getCumLogProbs)
.def("set_cum_log_prob", &GenLlmReq::setCumLogProb, py::arg("cum_log_prob"), py::arg("beam"))
.def("update_num_tokens_per_iteration", &GenLlmReq::updateNumTokensPerIteration,
py::arg("num_tokens_per_iteration"), py::arg("model_config"))
.def_property_readonly("orig_prompt_len", &GenLlmReq::getOrigPromptLen)
.def("has_draft_tokens", &GenLlmReq::hasDraftTokens)
.def("move_to_next_context_chunk", &GenLlmReq::moveToNextContextChunk)
.def_property_readonly("is_last_context_chunk", &GenLlmReq::isLastContextChunk)
.def_property_readonly("is_first_context_chunk", &GenLlmReq::isFirstContextChunk)
.def_property_readonly("context_remaining_length", &GenLlmReq::getContextRemainingLength)
.def_property_readonly("context_logits", &GenLlmReq::getContextLogitsHost)
.def_property_readonly("num_draft_tokens", &GenLlmReq::getNumDraftTokens)
.def("set_finished_reason", &GenLlmReq::setFinishedReason, py::arg("finish_reason"), py::arg("beam"))
.def_property_readonly("is_finished", &GenLlmReq::isFinished)
.def_property_readonly("is_finished_due_to_length", &GenLlmReq::isFinishedDueToLength)
.def_property_readonly("is_finished_due_to_cancellation", &GenLlmReq::isFinishedDueToCancellation)
.def_property(
"context_current_position", &GenLlmReq::getContextCurrentPosition, &GenLlmReq::setContextCurrentPosition)
.def_property_readonly("prepopulated_prompt_len", &GenLlmReq::getPrepopulatedPromptLen)
.def_property(
"guided_decoding_params", &GenLlmReq::getGuidedDecodingParams, &GenLlmReq::setGuidedDecodingParams)
.def_property("context_phase_params", &GenLlmReq::getContextPhaseParams, &GenLlmReq::setContextPhaseParams)
.def_property_readonly("is_context_only_request", &GenLlmReq::isContextOnlyRequest)
.def_property_readonly("is_generation_only_request", &GenLlmReq::isGenerationOnlyRequest)
.def_property_readonly("is_generation_to_complete_state", &GenLlmReq::isGenerationToCompleteState)
.def_property_readonly("is_generation_complete_state", &GenLlmReq::isGenerationCompleteState)
.def_property_readonly("is_context_finished", &GenLlmReq::isContextFinished)
.def_property_readonly("is_disagg_generation_init_state", &GenLlmReq::isDisaggGenerationInitState)
.def_property_readonly(
"is_disagg_generation_transmission_complete", &GenLlmReq::isDisaggGenerationTransmissionComplete)
.def_property_readonly(
"is_disagg_generation_transmission_in_progress", &GenLlmReq::isDisaggGenerationTransmissionInProgress)
.def_property_readonly("is_encoder_init_state", &GenLlmReq::isEncoderInitState)
.def_property_readonly("is_context_init_state", &GenLlmReq::isContextInitState)
.def_property_readonly("is_generation_in_progress_state", &GenLlmReq::isGenerationInProgressState)
.def_property_readonly("is_disagg_context_transmission_state", &GenLlmReq::isDisaggContextTransmissionState)
.def_property_readonly("is_disagg_context_complete_state", &GenLlmReq::isDisaggContextCompleteState)
.def_property_readonly("stage", &GenLlmReq::getRequestStage)
.def_property_readonly("kv_cache_transfer_time_ms", &GenLlmReq::getKvCacheTransferTimeMS)
.def_property_readonly("kv_cache_size", &GenLlmReq::getKvCacheSize)
.def_property_readonly("avg_decoded_tokens_per_iter", &GenLlmReq::getAvgDecodedTokensPerIter)
.def_property_readonly("alloc_total_blocks", &GenLlmReq::getAllocTotalBlocksPerRequest)
.def_property_readonly("alloc_new_blocks", &GenLlmReq::getAllocNewBlocksPerRequest)
.def("alloc_context_logits", &GenLlmReq::allocContextLogitsHost, py::arg("vocab_size"), py::arg("logit_dtype"))
.def_property_readonly("reused_blocks", &GenLlmReq::getReusedBlocksPerRequest)
.def_property_readonly("missed_blocks", &GenLlmReq::getMissedBlocksPerRequest)
.def_property_readonly("kv_cache_hit_rate", &GenLlmReq::getKVCacheHitRatePerRequest)
.def_property_readonly("llm_request_type", &GenLlmReq::getLlmRequestType)
.def_property_readonly("parent_request_id", &GenLlmReq::getParentRequestId)
.def_property_readonly("is_child", &GenLlmReq::isChild)
.def_property_readonly("cache_salt_id", &GenLlmReq::getCacheSaltID)
.def_property_readonly("multimodal_hashes",
[](GenLlmReq& self)
{
std::optional<std::vector<std::vector<GenLlmReq::SizeType32>>> hashes = std::nullopt;
if (self.getMultimodalHashes())
{
hashes = *self.getMultimodalHashes().value();
}
return hashes;
})
.def_property_readonly("multimodal_positions",
[](GenLlmReq& self)
{
std::optional<std::vector<GenLlmReq::SizeType32>> positions = std::nullopt;
if (self.getMultimodalPositions())
{
positions = *self.getMultimodalPositions().value();
}
return positions;
})
.def_property_readonly("multimodal_lengths",
[](GenLlmReq& self)
{
std::optional<std::vector<GenLlmReq::SizeType32>> lengths = std::nullopt;
if (self.getMultimodalLengths())
{
lengths = *self.getMultimodalLengths().value();
}
return lengths;
})
.def_property_readonly("position_ids",
[](GenLlmReq& self)
{
std::optional<std::vector<GenLlmReq::SizeType32>> positionIds = std::nullopt;
if (self.getPositionIds())
{
positionIds = *self.getPositionIds().value();
}
return positionIds;
})
.def_property(
"draft_tokens",
[](GenLlmReq& self)
{
std::optional<GenLlmReq::VecTokens> draftTokens = std::nullopt;
if (self.hasDraftTokens())
{
draftTokens = *self.getDraftTokens();
}
return draftTokens;
},
[](GenLlmReq& self, std::optional<GenLlmReq::VecTokens> const& draftTokens)
{
if (draftTokens)
{
self.setDraftTokens(std::make_shared<GenLlmReq::VecTokens>(draftTokens.value()));
}
})
.def_property("is_dummy_request", &GenLlmReq::isDummyRequest, &GenLlmReq::setIsDummyRequest)
.def_property_readonly("return_perf_metrics", &GenLlmReq::getReturnPerfMetrics)
.def_property("use_draft_model", &GenLlmReq::useDraftModel, &GenLlmReq::setUseDraftModel)
.def("get_unique_tokens", py::overload_cast<GenLlmReq::SizeType32>(&GenLlmReq::getUniqueTokens, py::const_),
py::arg("beam"))
.def("get_unique_tokens", py::overload_cast<>(&GenLlmReq::getUniqueTokens, py::const_))
.def("get_encoder_unique_tokens",
[](GenLlmReq& self)
{
auto const& encoderUniqueTokens = self.getEncoderUniqueTokens();
if (encoderUniqueTokens.has_value() && encoderUniqueTokens.value())
{
return std::optional<GenLlmReq::VecUniqueTokens>(*encoderUniqueTokens.value());
}
return std::optional<GenLlmReq::VecUniqueTokens>(std::nullopt);
});
py::classh<tb::LlmRequest, GenLlmReq>(m, "LlmRequest", pybind11::dynamic_attr())
.def(py::init<>(
[](tb::LlmRequest::RequestIdType request_id, tb::LlmRequest::SizeType32 max_new_tokens,
std::vector<tb::LlmRequest::TokenIdType> input_tokens, runtime::SamplingConfig sampling_config,
bool is_streaming, std::optional<tb::LlmRequest::SizeType32> end_id,
std::optional<tb::LlmRequest::SizeType32> pad_id, std::optional<at::Tensor> embedding_bias,
std::optional<at::Tensor> bad_words_list, std::optional<at::Tensor> stop_words_list,
std::optional<std::vector<tb::LlmRequest::SizeType32>> position_ids,
std::optional<at::Tensor> prompt_embedding_table,
std::optional<tb::LlmRequest::SizeType32> prompt_vocab_size,
std::optional<std::vector<std::vector<tb::LlmRequest::SizeType32>>> multimodal_hashes,
std::optional<std::vector<tb::LlmRequest::SizeType32>> multimodal_positions,
std::optional<std::vector<tb::LlmRequest::SizeType32>> multimodal_lengths,
std::optional<at::Tensor> multimodal_embedding, std::optional<at::Tensor> mrope_rotary_cos_sin,
std::optional<tb::LlmRequest::SizeType32> mrope_position_deltas,
std::optional<LoraTaskIdType> lora_task_id, std::optional<at::Tensor> lora_weights,
std::optional<at::Tensor> lora_config,
std::optional<executor::LookaheadDecodingConfig> lookahead_config,
std::optional<executor::KvCacheRetentionConfig> kv_cache_retention_config, bool return_log_probs,
bool return_context_logits, bool return_generation_logits,
std::optional<tb::LlmRequest::VecTokens> draft_tokens, std::optional<at::Tensor> draft_logits,
bool exclude_input_from_output,
std::optional<tb::LlmRequest::LogitsPostProcessor> logits_post_processor,
bool apply_logits_post_processor_batched,
std::optional<tb::LlmRequest::VecTokens> encoder_input_tokens, bool return_encoder_output,
std::optional<tb::LlmRequest::RequestIdType> client_id, executor::PriorityType priority,
std::optional<at::Tensor> encoder_input_features,
std::optional<tb::LlmRequest::SizeType32> encoder_output_length,
std::optional<at::Tensor> cross_attention_mask, tb::LlmRequestType llm_request_type,
std::optional<tb::LlmRequest::VecTokenExtraIds> input_token_extra_ids,
tb::LlmRequest::SizeType32 num_return_sequences, std::optional<executor::EagleConfig> eagle_config,
std::optional<at::Tensor> skip_cross_attn_blocks, bool return_perf_metrics,
std::optional<executor::GuidedDecodingParams> guided_decoding_params,
std::optional<tb::LlmRequest::SizeType32> language_adapter_uid,
std::optional<tb::LlmRequest::MillisecondsType> allotted_time_ms,
std::optional<executor::ContextPhaseParams> context_phase_params,
std::optional<tb::LlmRequest::CacheSaltIDType> cache_salt_id,
std::optional<tb::LlmRequest::TimePoint> arrival_time)
{
auto makeOptionalTensor = [](std::optional<at::Tensor> const& atTensor, bool unsqueeze = false)
{
std::optional<tb::LlmRequest::TensorPtr> tensorPtr = std::nullopt;
if (atTensor)
{
tensorPtr = tr::TorchView::of(atTensor.value());
if (unsqueeze)
{
(*tensorPtr)->unsqueeze(0);
}
}
return tensorPtr;
};
auto embedding_bias_tensor_ptr = makeOptionalTensor(embedding_bias, true);
auto bad_words_list_tensor_ptr = makeOptionalTensor(bad_words_list, true);
auto stop_words_list_tensor_ptr = makeOptionalTensor(stop_words_list, true);
auto prompt_embedding_table_tensor_ptr = makeOptionalTensor(prompt_embedding_table);
auto multimodal_embedding_tensor_ptr = makeOptionalTensor(multimodal_embedding);
auto lora_weights_tensor_ptr = makeOptionalTensor(lora_weights);
auto mrope_rotary_cos_sin_tensor_ptr = makeOptionalTensor(mrope_rotary_cos_sin);
auto lora_config_tensor_ptr = makeOptionalTensor(lora_config);
auto draft_logits_tensor_ptr = makeOptionalTensor(draft_logits);
auto encoder_input_features_tensor_ptr = makeOptionalTensor(encoder_input_features);
auto cross_attention_mask_tensor_ptr = makeOptionalTensor(cross_attention_mask);
auto skip_cross_attn_blocks_tensor_ptr = makeOptionalTensor(skip_cross_attn_blocks);
return tb::LlmRequest{request_id, max_new_tokens, input_tokens, sampling_config, is_streaming,
end_id, pad_id, embedding_bias_tensor_ptr, bad_words_list_tensor_ptr,
stop_words_list_tensor_ptr, position_ids, prompt_embedding_table_tensor_ptr, prompt_vocab_size,
multimodal_hashes, multimodal_positions, multimodal_lengths, multimodal_embedding_tensor_ptr,
mrope_rotary_cos_sin_tensor_ptr, mrope_position_deltas, lora_task_id, lora_weights_tensor_ptr,
lora_config_tensor_ptr, lookahead_config, kv_cache_retention_config, return_log_probs,
return_context_logits, return_generation_logits, draft_tokens, draft_logits_tensor_ptr,
exclude_input_from_output, logits_post_processor, apply_logits_post_processor_batched,
encoder_input_tokens, return_encoder_output, client_id, priority,
encoder_input_features_tensor_ptr, encoder_output_length, cross_attention_mask_tensor_ptr,
llm_request_type, input_token_extra_ids, num_return_sequences, eagle_config,
skip_cross_attn_blocks_tensor_ptr, return_perf_metrics, guided_decoding_params,
language_adapter_uid, allotted_time_ms, context_phase_params, cache_salt_id, arrival_time};
}),
py::arg("request_id"), py::arg("max_new_tokens"), py::arg("input_tokens"), py::arg("sampling_config"),
py::arg("is_streaming"), py::arg("end_id") = std::nullopt, py::arg("pad_id") = std::nullopt,
py::arg("embedding_bias") = std::nullopt, py::arg("bad_words_list") = std::nullopt,
py::arg("stop_words_list") = std::nullopt, py::arg("position_ids") = std::nullopt,
py::arg("prompt_embedding_table") = std::nullopt, py::arg("prompt_vocab_size") = std::nullopt,
py::arg("multimodal_hashes") = std::nullopt, py::arg("multimodal_positions") = std::nullopt,
py::arg("multimodal_lengths") = std::nullopt, py::arg("multimodal_embedding") = std::nullopt,
py::arg("mrope_rotary_cos_sin") = std::nullopt, py::arg("mrope_position_deltas") = std::nullopt,
py::arg("lora_task_id") = std::nullopt, py::arg("lora_weights") = std::nullopt,
py::arg("lora_config") = std::nullopt, py::arg("lookahead_config") = std::nullopt,
py::arg("kv_cache_retention_config") = std::nullopt, py::arg("return_log_probs") = false,
py::arg("return_context_logits") = false, py::arg("return_generation_logits") = false,
py::arg("draft_tokens") = std::nullopt, py::arg("draft_logits") = std::nullopt,
py::arg("exclude_input_from_output") = false, py::arg("logits_post_processor") = std::nullopt,
py::arg("apply_logits_post_processor_batched") = false, py::arg("encoder_input_tokens") = std::nullopt,
py::arg("return_encoder_output") = false, py::arg("client_id") = std::nullopt,
py::arg("priority") = executor::Request::kDefaultPriority, py::arg("encoder_input_features") = std::nullopt,
py::arg("encoder_output_len") = std::nullopt, py::arg("cross_attention_mask") = std::nullopt,
py::arg_v("llm_request_type", tb::LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION,
"LlmRequestType.LLMREQUEST_TYPE_CONTEXT_AND_GENERATION"),
py::arg("input_token_extra_ids") = std::nullopt, py::arg("num_return_sequences") = 1,
py::arg("eagle_config") = std::nullopt, py::arg("skip_cross_attn_blocks") = std::nullopt,
py::arg("return_perf_metrics") = false, py::arg("guided_decoding_params") = std::nullopt,
py::arg("language_adapter_uid") = std::nullopt, py::arg("allotted_time_ms") = std::nullopt,
py::arg("context_phase_params") = std::nullopt, py::arg("cache_salt_id") = std::nullopt,
py::arg("arrival_time") = std::nullopt)
.def("check_token_id_range", &tb::LlmRequest::checkTokenIdRange, py::arg("vocab_size"))
.def(py::init<tb::LlmRequest const&>())
.def("validate", &tb::LlmRequest::validate, py::arg("max_input_len"), py::arg("max_seq_len"),
py::arg("max_draft_len"), py::arg("vocab_size_padded"), py::arg("max_endocer_input_len") = std::nullopt,
py::arg("enable_kv_cache_reuse") = false)
.def("create_response", &tb::LlmRequest::createResponse, py::arg("use_fast_logits") = false,
py::arg("mpi_world_rank") = 0)
.def("create_child_request", &tb::LlmRequest::createChildRequest, py::arg("child_id"))
.def("create_result", &tb::LlmRequest::createResult, py::arg("use_fast_logits") = false,
py::arg("mpi_world_rank") = 0)
.def("create_serialized_result",
[](tb::LlmRequest& self, bool use_fast_logits = false, int mpi_world_rank = 0)
{
std::vector<char> serialized_result;
bool is_final = false;
self.createSerializedResult(serialized_result, is_final, use_fast_logits, mpi_world_rank);
return std::make_tuple(py::bytes(serialized_result.data(), serialized_result.size()), is_final);
})
.def("move_prompt_embedding_table_to_gpu", &tb::LlmRequest::movePromptEmbeddingTableToGpu, py::arg("manager"))
.def("move_lora_weights_to_gpu", &tb::LlmRequest::moveLoraWeightsToGpu, py::arg("manager"))
.def("finish_by_reason", &tb::LlmRequest::finishByReason, py::arg("finish_reason"))
.def("set_first_scheduled_time", &tb::LlmRequest::setFirstScheduledTime)
.def("update_perf_metrics", &tb::LlmRequest::updatePerfMetrics, py::arg("iter_counter"))
.def("remove_lora_tensors", &tb::LlmRequest::removeLoraTensors)
.def_readwrite_static("global_steady_clock_offset", &tb::LlmRequest::sGlobalSteadyClockOffset);
py::classh<tb::SequenceSlotManager>(m, "SequenceSlotManager")
.def(py::init<tb::SequenceSlotManager::SlotIdType, uint64_t>(), py::arg("max_num_slots"),
py::arg("max_sequence_idle_microseconds"))
.def("get_sequence_slot", &tb::SequenceSlotManager::getSequenceSlot, py::arg("start_flag"),
py::arg("sequence_id"))
.def("free_sequence_slot", &tb::SequenceSlotManager::freeSequenceSlot, py::arg("sequence_id"))
.def("free_idle_sequence_slots", &tb::SequenceSlotManager::freeIdleSequenceSlots);
py::classh<tb::rnn_state_manager::RnnStateManager>(m, "RnnStateManager")
.def(py::init<tr::SizeType32, tr::ModelConfig, tr::WorldConfig, tr::BufferManager>(),
py::arg("max_num_sequences"), py::arg("model_config"), py::arg("world_config"), py::arg("buffer_manager"));
m.def(
"add_new_tokens_to_requests",
[](std::vector<std::shared_ptr<tb::LlmRequest>>& requests,
std::vector<tb::LlmRequest::TokenIdType> const& tokens, int beam_idx)
{
TLLM_CHECK_WITH_INFO(requests.size() == tokens.size(), "Expected the same number of requests and tokens.");
for (int i = 0; i < requests.size(); ++i)
{
requests[i]->addNewToken(tokens[i], beam_idx);
}
},
py::arg("requests"), py::arg("tokens"), py::arg("beam_idx"),
"Add new tokens to multiple LLM requests. The tokens vector should contain tokens for beam beam_idx of all "
"requests in order.");
m.def(
"make_decoding_batch_input",
[](tb::DecoderInputBuffers& decoderInputBuffers, runtime::decoder::DecoderState& decoderState,
std::vector<std::shared_ptr<tb::LlmRequest>> const& contextRequests,
std::vector<std::shared_ptr<tb::LlmRequest>> const& genRequests, tr::ITensor::SharedPtr const& logits,
int beamWidth, std::vector<int> const& numContextLogitsPrefixSum, tr::BufferManager const& manager)
{
std::vector<int> activeSlots;
std::vector<int> generationSteps;
std::vector<std::vector<tr::ITensor::SharedConstPtr>> logitsVec = {{}};
for (int i = 0; i < contextRequests.size(); ++i)
{
if (contextRequests[i]->isLastContextChunk())
{
activeSlots.push_back(*contextRequests[i]->mSeqSlot);
generationSteps.push_back(contextRequests[i]->getDecodingIter());
auto contextLogitsOffset = numContextLogitsPrefixSum[i + 1] - 1;
tr::ITensor::SharedPtr logitsView = ITensor::slice(logits, contextLogitsOffset, 1);
if (beamWidth > 1)
{
// Tile logits of context requests
auto const logitsShape = logitsView->getShape();
auto const logitsType = logitsView->getDataType();
auto decoderLogits = manager.gpu(ITensor::makeShape({beamWidth, logitsShape.d[1]}), logitsType);
tensorrt_llm::runtime::kernels::tileTensor(
*decoderLogits, *logitsView, beamWidth, manager.getStream());
decoderLogits->unsqueeze(0);
logitsVec[0].push_back(std::move(decoderLogits));
}
else
{
logitsView->unsqueeze(1);
logitsVec[0].push_back(std::move(logitsView));
}
}
}
auto genLogitsOffset = numContextLogitsPrefixSum.back();
for (int i = 0; i < genRequests.size(); ++i)
{
if (genRequests[i]->isGenerationInProgressState())
{
activeSlots.push_back(*genRequests[i]->mSeqSlot);
generationSteps.push_back(genRequests[i]->getDecodingIter());
auto logitsOffset = genLogitsOffset + i * beamWidth;
auto numberOfLogits = beamWidth;
tr::ITensor::SharedPtr logitsView = ITensor::slice(logits, logitsOffset, numberOfLogits);
logitsView->unsqueeze(0);
logitsVec[0].push_back(std::move(logitsView));
}
}
auto& batchSlots = decoderInputBuffers.forwardBatchSlots;
batchSlots[0]->resize(activeSlots.size());
auto batchSlotsRange = tr::BufferRange<SizeType32>(*batchSlots[0]);
for (int i = 0; i < activeSlots.size(); ++i)
{
batchSlotsRange[i] = activeSlots[i];
}
decoderInputBuffers.batchLogits = logitsVec;
auto const maxBeamWidth = decoderState.getMaxBeamWidth();
if (maxBeamWidth > 1)
{
// For Variable-Beam-Width-Search
decoderState.getJointDecodingInput().generationSteps = generationSteps;
}
},
py::arg("decoder_input_buffers"), py::arg("decoder_state"), py::arg("context_requests"),
py::arg("generation_requests"), py::arg("logits"), py::arg("beam_width"),
py::arg("num_context_logits_prefix_sum"), py::arg("buffer_manager"), "Make decoding batch input.");
}
} // namespace tensorrt_llm::pybind::batch_manager
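
A hedged sketch of driving the LlmRequest binding above from Python follows. The LlmRequest keyword arguments, methods, and properties mirror the py::arg and .def names registered in this file, and the module path matches the exception names registered at the top of initBindings; the SamplingConfig import path and constructor are assumptions, since that class is bound elsewhere.

# Hedged sketch: creating and updating an LlmRequest through these bindings.
from tensorrt_llm.bindings import SamplingConfig                # assumed path
from tensorrt_llm.bindings.internal import batch_manager as bm  # from exception names above

request = bm.LlmRequest(
    request_id=0,
    max_new_tokens=16,
    input_tokens=[101, 2023, 2003],
    sampling_config=SamplingConfig(1),  # beam width 1 (assumed constructor)
    is_streaming=False,
)

# A generation loop appends decoded tokens per beam via the bound helpers.
request.add_new_token(42, beam=0)
print(request.get_tokens(0), request.max_beam_num_tokens)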

cpp/tensorrt_llm/pybind/batch_manager/bindings.h

@@ -1,28 +0,0 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/pybind/common/customCasters.h"
#include <pybind11/pybind11.h>
namespace tensorrt_llm::pybind::batch_manager
{
void initBindings(pybind11::module_& m);
}

cpp/tensorrt_llm/pybind/batch_manager/buffers.cpp

@@ -1,75 +0,0 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "buffers.h"
#include "tensorrt_llm/batch_manager/decoderBuffers.h"
#include "tensorrt_llm/runtime/torch.h"
#include <ATen/ATen.h>
#include <pybind11/functional.h>
#include <pybind11/operators.h>
#include <pybind11/stl.h>
#include <pybind11/stl_bind.h>
#include <torch/extension.h>
namespace py = pybind11;
namespace tb = tensorrt_llm::batch_manager;
namespace tr = tensorrt_llm::runtime;
using tr::SizeType32;
namespace tensorrt_llm::pybind::batch_manager
{
void Buffers::initBindings(pybind11::module_& m)
{
py::class_<tb::DecoderInputBuffers>(m, "DecoderInputBuffers")
.def(py::init<tr::SizeType32, tr::SizeType32, tr::BufferManager>(), py::arg("max_batch_size"),
py::arg("max_tokens_per_engine_step"), py::arg("manager"))
.def_readwrite("setup_batch_slots", &tb::DecoderInputBuffers::setupBatchSlots)
.def_readwrite("setup_batch_slots_device", &tb::DecoderInputBuffers::setupBatchSlotsDevice)
.def_readwrite("fill_values", &tb::DecoderInputBuffers::fillValues)
.def_readwrite("fill_values_device", &tb::DecoderInputBuffers::fillValuesDevice)
.def_readwrite("inputs_ids", &tb::DecoderInputBuffers::inputsIds)
.def_readwrite("batch_logits", &tb::DecoderInputBuffers::batchLogits)
.def_readwrite("forward_batch_slots", &tb::DecoderInputBuffers::forwardBatchSlots)
.def_readwrite("decoder_logits", &tb::DecoderInputBuffers::decoderLogits)
.def_readwrite("max_decoder_steps", &tb::DecoderInputBuffers::maxDecoderSteps);
py::class_<tb::DecoderOutputBuffers>(m, "DecoderOutputBuffers")
.def_readwrite("sequence_lengths_host", &tb::DecoderOutputBuffers::sequenceLengthsHost)
.def_readwrite("finished_sum_host", &tb::DecoderOutputBuffers::finishedSumHost)
.def_property_readonly("new_output_tokens_host",
[](tb::DecoderOutputBuffers& self) { return tr::Torch::tensor(self.newOutputTokensHost); })
.def_readwrite("cum_log_probs_host", &tb::DecoderOutputBuffers::cumLogProbsHost)
.def_readwrite("log_probs_host", &tb::DecoderOutputBuffers::logProbsHost)
.def_readwrite("finish_reasons_host", &tb::DecoderOutputBuffers::finishReasonsHost);
py::class_<tb::SlotDecoderBuffers>(m, "SlotDecoderBuffers")
.def(py::init<runtime::SizeType32, runtime::SizeType32, runtime::BufferManager const&>(),
py::arg("max_beam_width"), py::arg("max_seq_len"), py::arg("buffer_manager"))
.def_readwrite("output_ids", &tb::SlotDecoderBuffers::outputIds)
.def_readwrite("output_ids_host", &tb::SlotDecoderBuffers::outputIdsHost)
.def_readwrite("sequence_lengths_host", &tb::SlotDecoderBuffers::sequenceLengthsHost)
.def_readwrite("cum_log_probs", &tb::SlotDecoderBuffers::cumLogProbs)
.def_readwrite("cum_log_probs_host", &tb::SlotDecoderBuffers::cumLogProbsHost)
.def_readwrite("log_probs", &tb::SlotDecoderBuffers::logProbs)
.def_readwrite("log_probs_host", &tb::SlotDecoderBuffers::logProbsHost)
.def_readwrite("finish_reasons_host", &tb::SlotDecoderBuffers::finishReasonsHost);
}
} // namespace tensorrt_llm::pybind::batch_manager

cpp/tensorrt_llm/pybind/batch_manager/buffers.h

@@ -1,30 +0,0 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/pybind/common/customCasters.h"
#include <pybind11/pybind11.h>
namespace tensorrt_llm::pybind::batch_manager
{
class Buffers
{
public:
static void initBindings(pybind11::module_& m);
};
} // namespace tensorrt_llm::pybind::batch_manager

cpp/tensorrt_llm/pybind/batch_manager/cacheTransceiver.cpp

@@ -1,204 +0,0 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "cacheTransceiver.h"
#include "tensorrt_llm/batch_manager/cacheTransceiver.h"
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
#include "tensorrt_llm/common/bindingUtils.h"
#include "tensorrt_llm/executor/executor.h"
#include <ATen/ATen.h>
#include <pybind11/functional.h>
#include <pybind11/operators.h>
#include <pybind11/pytypes.h>
#include <pybind11/stl.h>
#include <pybind11/stl_bind.h>
#include <torch/extension.h>
#include <typeinfo>
using SizeType32 = tensorrt_llm::runtime::SizeType32;
namespace tb = tensorrt_llm::batch_manager;
namespace
{
class PyCacheTransceiver : public tb::BaseCacheTransceiver
{
public:
// using BaseCacheTransceiver::BaseCacheTransceiver; // Inherit constructors
void respondAndSendAsync(tb::LlmRequest* llmRequest) override
{
PYBIND11_OVERLOAD_PURE(void, tb::BaseCacheTransceiver, respondAndSendAsync, llmRequest);
}
void requestAndReceiveSync(tb::LlmRequest* llmRequest) override
{
PYBIND11_OVERLOAD_PURE(void, tb::BaseCacheTransceiver, requestAndReceiveSync, llmRequest);
}
void requestAndReceiveAsync(tb::LlmRequest* llmRequest) override
{
PYBIND11_OVERLOAD_PURE(void, tb::BaseCacheTransceiver, requestAndReceiveAsync, llmRequest);
}
using RequestStatuses = tb::RequestStatuses;
RequestStatuses checkContextTransferStatus(
std::optional<int> const& atLeastRequestNum = std::nullopt, bool markComplete = false) override
{
PYBIND11_OVERLOAD_PURE(
RequestStatuses, tb::BaseCacheTransceiver, checkContextTransferStatus, atLeastRequestNum, markComplete);
}
void checkGenTransferStatus(std::optional<int> const& atLeastRequestNum = std::nullopt) override
{
PYBIND11_OVERLOAD_PURE(void, tb::BaseCacheTransceiver, checkGenTransferStatus, atLeastRequestNum);
}
bool checkGenTransferComplete() const override
{
PYBIND11_OVERLOAD_PURE(bool, tb::BaseCacheTransceiver, checkGenTransferComplete);
}
bool cancelRequest(tb::LlmRequest* llmRequest) override
{
PYBIND11_OVERLOAD_PURE(bool, tb::BaseCacheTransceiver, cancelRequest, llmRequest);
}
};
} // namespace
void tb::CacheTransceiverBindings::initBindings(py::module_& m)
{
py::classh<tb::BaseCacheTransceiver, PyCacheTransceiver>(m, "BaseCacheTransceiver")
.def("respond_and_send_async", &BaseCacheTransceiver::respondAndSendAsync)
.def("request_and_receive_sync", &BaseCacheTransceiver::requestAndReceiveSync)
.def("request_and_receive_async", &BaseCacheTransceiver::requestAndReceiveAsync)
.def(
"check_context_transfer_status",
[](tb::BaseCacheTransceiver& self, std::optional<int> const& atLeastRequestNum, bool markComplete = false)
{
RequestStatuses result;
{
py::gil_scoped_release release;
result = self.checkContextTransferStatus(atLeastRequestNum, markComplete);
}
auto completedRequestIds
= std::vector<int64_t>(result.completedRequestIds.begin(), result.completedRequestIds.end());
auto errorRequestIds
= std::vector<int64_t>(result.errorRequestIds.begin(), result.errorRequestIds.end());
return py::make_tuple(completedRequestIds, errorRequestIds);
},
py::arg("at_least_request_num") = std::nullopt, py::arg("mark_complete") = false)
.def("check_gen_transfer_status", &BaseCacheTransceiver::checkGenTransferStatus,
py::call_guard<py::gil_scoped_release>())
.def("check_gen_transfer_complete", &BaseCacheTransceiver::checkGenTransferComplete)
.def("cancel_request", &BaseCacheTransceiver::cancelRequest);
py::enum_<executor::kv_cache::CacheState::AttentionType>(m, "AttentionType")
.value("DEFAULT", executor::kv_cache::CacheState::AttentionType::kDEFAULT)
.value("MLA", executor::kv_cache::CacheState::AttentionType::kMLA);
py::classh<tb::CacheTransceiver, tb::BaseCacheTransceiver>(m, "CacheTransceiver")
.def(py::init<tb::kv_cache_manager::BaseKVCacheManager*, std::vector<SizeType32>, SizeType32, SizeType32,
runtime::WorldConfig, std::vector<SizeType32>, nvinfer1::DataType,
executor::kv_cache::CacheState::AttentionType, std::optional<executor::CacheTransceiverConfig>>(),
py::arg("cache_manager"), py::arg("num_kv_heads_per_layer"), py::arg("size_per_head"),
py::arg("tokens_per_block"), py::arg("world_config"), py::arg("attention_layer_num_per_pp"),
py::arg("dtype"), py::arg("attention_type"), py::arg("cache_transceiver_config") = std::nullopt);
py::classh<tb::CacheTransceiverComm>(m, "CacheTransceiverComm")
.def(py::init(
[](py::object pg_obj, std::string pybind11_abi)
{
return new CacheTransceiverComm(
common::get_intrusive_ptr<c10d::ProcessGroup, py::error_already_set>(
pg_obj.ptr(), pybind11_abi));
}),
py::arg("process_group"), py::arg("pybind11_abi"))
.def("get_rank", &tb::CacheTransceiverComm::getRank)
.def("get_size", &tb::CacheTransceiverComm::getSize)
.def("split", &tb::CacheTransceiverComm::split, py::arg("color"), py::arg("key"))
.def(
"allgather",
[](tb::CacheTransceiverComm const& self, int64_t input)
{
std::vector<int64_t> out(static_cast<size_t>(self.getSize()));
c10d::AllgatherOptions options;
bool ok = self.allgather(input, std::ref(out), options);
return py::make_tuple(ok, out);
},
py::arg("input"))
.def(
"allgather",
[](tb::CacheTransceiverComm const& self, double input)
{
std::vector<double> out(static_cast<size_t>(self.getSize()));
c10d::AllgatherOptions options;
bool ok = self.allgather(input, std::ref(out), options);
return py::make_tuple(ok, out);
},
py::arg("input"))
.def(
"allgather",
[](tb::CacheTransceiverComm const& self, char input)
{
std::vector<char> out(static_cast<size_t>(self.getSize()));
c10d::AllgatherOptions options;
bool ok = self.allgather(input, std::ref(out), options);
return py::make_tuple(ok, out);
},
py::arg("input"))
.def(
"allgatherv",
[](tb::CacheTransceiverComm const& self, std::vector<int64_t> input, std::vector<int> const& sizes)
{
int total_size = std::accumulate(sizes.begin(), sizes.end(), 0);
std::vector<int64_t> output(total_size);
bool ok = self.allgatherv(std::ref(input), std::ref(output), std::cref(sizes));
return py::make_tuple(ok, output);
},
py::arg("input"), py::arg("sizes"))
.def(
"allgatherv",
[](tb::CacheTransceiverComm const& self, std::vector<double> input, std::vector<int> const& sizes)
{
int total_size = std::accumulate(sizes.begin(), sizes.end(), 0);
std::vector<double> output(total_size);
bool ok = self.allgatherv(std::ref(input), std::ref(output), std::cref(sizes));
return py::make_tuple(ok, output);
},
py::arg("input"), py::arg("sizes"))
.def(
"allgatherv",
[](tb::CacheTransceiverComm const& self, std::vector<char> input, std::vector<int> const& sizes)
{
int total_size = std::accumulate(sizes.begin(), sizes.end(), 0);
std::vector<char> output(total_size);
bool ok = self.allgatherv(std::ref(input), std::ref(output), std::cref(sizes));
return py::make_tuple(ok, output);
},
py::arg("input"), py::arg("sizes"));
py::class_<tb::kv_cache_manager::CacheTransBufferManager>(m, "CacheTransBufferManager")
.def(py::init<tb::kv_cache_manager::BaseKVCacheManager*, std::optional<size_t>>(), py::arg("cache_manager"),
py::arg("max_num_tokens") = std::nullopt)
.def_static("pre_alloc_buffer_size", &tb::kv_cache_manager::CacheTransBufferManager::preAllocBufferSize,
py::arg("cache_size_bytes_per_token_per_window"), py::arg("tokens_per_block"),
py::arg("cache_transceiver_config") = py::none());
}

cpp/tensorrt_llm/pybind/batch_manager/cacheTransceiver.h

@@ -1,30 +0,0 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/pybind/common/customCasters.h"
#include <pybind11/pybind11.h>
namespace tensorrt_llm::batch_manager
{
class CacheTransceiverBindings
{
public:
static void initBindings(pybind11::module_& m);
};
} // namespace tensorrt_llm::batch_manager

cpp/tensorrt_llm/pybind/batch_manager/kvCacheConnector.cpp

@@ -1,47 +0,0 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/pybind/batch_manager/kvCacheConnector.h"
namespace
{
using KvCacheConnectorManager = tensorrt_llm::batch_manager::kv_connector::KvCacheConnectorManager;
namespace tb = tensorrt_llm::batch_manager;
class PyKvCacheConnectorManager : public KvCacheConnectorManager, py::trampoline_self_life_support
{
public:
using KvCacheConnectorManager::KvCacheConnectorManager;
SizeType32 getNumNewMatchedTokens(tb::LlmRequest const& request, SizeType32 numComputedTokens) override
{
PYBIND11_OVERRIDE_PURE_NAME(SizeType32, KvCacheConnectorManager, "get_num_new_matched_tokens",
getNumNewMatchedTokens, request, numComputedTokens);
}
};
} // namespace
void tensorrt_llm::batch_manager::kv_cache_manager::KVCacheManagerConnectorBindings::initBindings(py::module_& m)
{
py::class_<tb::kv_connector::KvCacheConnectorManager, PyKvCacheConnectorManager, py::smart_holder>(
m, "KvCacheConnectorManager")
.def(py::init<>())
.def("get_num_new_matched_tokens", &tb::kv_connector::KvCacheConnectorManager::getNumNewMatchedTokens,
py::arg("request"), py::arg("num_computed_tokens"));
}
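
Because the trampoline above routes the pure virtual getNumNewMatchedTokens to a Python method named "get_num_new_matched_tokens" (via PYBIND11_OVERRIDE_PURE_NAME), the abstract manager can be implemented in Python. A hedged sketch follows; the import path is an assumption, since the target submodule is assembled outside this file.

# Hedged sketch: a Python implementation of the abstract KvCacheConnectorManager.
from tensorrt_llm.bindings.internal import batch_manager as bm  # assumed path

class NoExternalCacheConnectorManager(bm.KvCacheConnectorManager):
    def get_num_new_matched_tokens(self, request, num_computed_tokens):
        # Report that no additional tokens are available from an external cache.
        return 0

manager = NoExternalCacheConnectorManager()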

cpp/tensorrt_llm/pybind/batch_manager/kvCacheConnector.h

@@ -1,39 +0,0 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/batch_manager/kvCacheConnector.h"
#include <pybind11/pybind11.h>
namespace py = pybind11;
namespace tensorrt_llm::batch_manager::kv_cache_manager
{
class KVCacheManagerConnectorBindings
{
public:
static void initBindings(pybind11::module_& m);
};
} // namespace tensorrt_llm::batch_manager::kv_cache_manager
namespace tensorrt_llm::pybind::batch_manager::kv_connector
{
using namespace tensorrt_llm::batch_manager::kv_connector;
} // namespace tensorrt_llm::pybind::batch_manager::kv_connector

cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp

@@ -1,570 +0,0 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kvCacheManager.h"
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
#include "tensorrt_llm/batch_manager/peftCacheManager.h"
#include "tensorrt_llm/pybind/common/bindTypes.h"
#include "tensorrt_llm/runtime/torch.h"
#include "tensorrt_llm/runtime/torchView.h"
#include <ATen/ATen.h>
#include <pybind11/functional.h>
#include <pybind11/operators.h>
#include <pybind11/stl.h>
#include <pybind11/stl_bind.h>
#include <torch/extension.h>
namespace tb = tensorrt_llm::batch_manager;
namespace tbc = tensorrt_llm::batch_manager::kv_connector;
namespace tbk = tensorrt_llm::batch_manager::kv_cache_manager;
namespace tr = tensorrt_llm::runtime;
namespace py = pybind11;
using BlockKey = tbk::BlockKey;
using VecUniqueTokens = tensorrt_llm::runtime::VecUniqueTokens;
using SizeType32 = tensorrt_llm::runtime::SizeType32;
using TokenIdType = tensorrt_llm::runtime::TokenIdType;
using VecTokens = std::vector<TokenIdType>;
using CudaStreamPtr = std::shared_ptr<tensorrt_llm::runtime::CudaStream>;
namespace
{
std::optional<tensorrt_llm::runtime::ITensor::UniquePtr> from_torch(std::optional<at::Tensor> torchPtr)
{
if (torchPtr)
{
return tr::TorchView::of(torchPtr.value());
}
return std::nullopt;
}
class PyKvCacheManager : public tbk::BaseKVCacheManager
{
public:
// using BaseKVCacheManager::BaseKVCacheManager; // Inherit constructors
void allocatePools(bool useUvm = false) override
{
PYBIND11_OVERLOAD_PURE(void, tbk::BaseKVCacheManager, allocatePools, useUvm);
}
void releasePools() override
{
PYBIND11_OVERLOAD_PURE(void, tbk::BaseKVCacheManager, releasePools);
}
void startScheduling() override
{
PYBIND11_OVERLOAD_PURE(void, tbk::BaseKVCacheManager, startScheduling);
}
SizeType32 getTokensPerBlock() const override
{
PYBIND11_OVERLOAD_PURE(SizeType32, tbk::BaseKVCacheManager, getTokensPerBlock);
}
SizeType32 getMaxNumBlocks() const override
{
PYBIND11_OVERLOAD_PURE(SizeType32, tbk::BaseKVCacheManager, getMaxNumBlocks);
}
SizeType32 getNumPools() const override
{
PYBIND11_OVERLOAD_PURE(SizeType32, tbk::BaseKVCacheManager, getNumPools);
}
tbk::KvCacheStats getKvCacheStats() const override
{
PYBIND11_OVERLOAD_PURE(tbk::KvCacheStats, tbk::BaseKVCacheManager, getKvCacheStats);
}
void addToken(tb::LlmRequest::RequestIdType requestId) override
{
PYBIND11_OVERLOAD_PURE(void, tbk::BaseKVCacheManager, addToken, requestId);
}
void addSequence(tb::LlmRequest::RequestIdType requestId, SizeType32 inputLength, SizeType32 beamWidth,
tensorrt_llm::common::OptionalRef<tb::LlmRequest> llmRequest = std::nullopt) override
{
PYBIND11_OVERLOAD_PURE(
void, tbk::BaseKVCacheManager, addSequence, requestId, inputLength, beamWidth, llmRequest);
}
std::optional<tbk::KVCacheBlock::IdType> removeSequence(tb::LlmRequest::RequestIdType requestId,
tensorrt_llm::common::OptionalRef<tb::LlmRequest const> llmRequest = std::nullopt,
bool pinOnRelease = false) override
{
PYBIND11_OVERLOAD_PURE(std::optional<tbk::KVCacheBlock::IdType>, tbk::BaseKVCacheManager, removeSequence,
requestId, llmRequest, pinOnRelease);
}
std::vector<tbk::KVCacheBlock::IdType> storeBlocksForReuse(tb::LlmRequest::RequestIdType requestId,
tensorrt_llm::common::OptionalRef<tb::LlmRequest const> llmRequest, bool pinBlocks) override
{
PYBIND11_OVERLOAD_PURE(std::vector<tbk::KVCacheBlock::IdType>, tbk::BaseKVCacheManager, storeBlocksForReuse,
requestId, llmRequest, pinBlocks);
}
tbk::GenerationRequest const& getSequence(tb::LlmRequest::RequestIdType requestId) const override
{
PYBIND11_OVERLOAD_PURE(tbk::GenerationRequest const&, tbk::BaseKVCacheManager, getSequence, requestId);
}
void schedulingRemoveSequence(tb::LlmRequest::RequestIdType requestId) override
{
PYBIND11_OVERLOAD_PURE(void, tbk::BaseKVCacheManager, schedulingRemoveSequence, requestId);
}
tensorrt_llm::runtime::ITensor::SharedPtr getBlockPoolPointers() const override
{
PYBIND11_OVERLOAD_PURE(
tensorrt_llm::runtime::ITensor::UniquePtr, tbk::BaseKVCacheManager, getBlockPoolPointers);
}
tensorrt_llm::runtime::ITensor::SharedPtr getLayerToPoolMapping() const override
{
PYBIND11_OVERLOAD_PURE(
tensorrt_llm::runtime::ITensor::UniquePtr, tbk::BaseKVCacheManager, getLayerToPoolMapping);
}
void getBlockOffsetsOfBatch(tensorrt_llm::runtime::ITensor& output, SizeType32 firstBatchSlotIdx,
SizeType32 batchSize, SizeType32 beamWidth) const override
{
PYBIND11_OVERLOAD_PURE(
void, tbk::BaseKVCacheManager, getBlockOffsetsOfBatch, output, firstBatchSlotIdx, batchSize, beamWidth);
}
SizeType32 copyBlockOffsets(tensorrt_llm::runtime::ITensor& output, SizeType32 outputSlotOffset,
tb::LlmRequest::RequestIdType requestId) const override
{
PYBIND11_OVERLOAD_PURE(
SizeType32, tbk::BaseKVCacheManager, copyBlockOffsets, output, outputSlotOffset, requestId);
}
bool isEnableBlockReuse() const override
{
PYBIND11_OVERLOAD_PURE(bool, tbk::BaseKVCacheManager, isEnableBlockReuse);
}
void rewindKVCache(tb::LlmRequest::RequestIdType requestId, SizeType32 rewindLengths) override
{
PYBIND11_OVERLOAD_PURE(void, tbk::BaseKVCacheManager, rewindKVCache, requestId, rewindLengths);
}
bool isCrossKv() const override
{
PYBIND11_OVERLOAD_PURE(bool, tbk::BaseKVCacheManager, isCrossKv);
}
std::optional<BlockKey> findNewContextBlock(
VecUniqueTokens const& uniqueTokens, tb::LlmRequest const& llmRequest) const override
{
PYBIND11_OVERLOAD_PURE(
std::optional<BlockKey>, tbk::BaseKVCacheManager, findNewContextBlock, uniqueTokens, llmRequest);
}
void storeContextBlocks(tb::LlmRequest const& llmRequest) override
{
PYBIND11_OVERLOAD_PURE(void, tbk::BaseKVCacheManager, storeContextBlocks, llmRequest);
}
std::vector<std::vector<SizeType32>> const& getCacheBlockIds(
tb::LlmRequest::RequestIdType requestId, SizeType32 windowSize) const override
{
PYBIND11_OVERLOAD_PURE(std::vector<std::vector<SizeType32>> const&, tbk::BaseKVCacheManager, getCacheBlockIds,
requestId, windowSize);
}
std::vector<std::vector<std::vector<SizeType32>>> getBatchCacheBlockIds(
std::vector<tb::LlmRequest::RequestIdType> const& requestIds, SizeType32 windowSize) const override
{
PYBIND11_OVERLOAD_PURE(std::vector<std::vector<std::vector<SizeType32>>>, tbk::BaseKVCacheManager,
getBatchCacheBlockIds, requestIds, windowSize);
}
SizeType32 getUsedNumBlocks() const override
{
PYBIND11_OVERLOAD_PURE(SizeType32, tbk::BaseKVCacheManager, getUsedNumBlocks);
}
SizeType32 getNumFreeBlocks() const override
{
PYBIND11_OVERLOAD_PURE(SizeType32, tbk::BaseKVCacheManager, getNumFreeBlocks);
}
tbk::BlockManager const& getBlockManager() const override
{
PYBIND11_OVERLOAD_PURE(tbk::BlockManager const&, tbk::BaseKVCacheManager, getBlockManager);
}
std::deque<tensorrt_llm::executor::KVCacheEvent> getLatestEvents(
std::optional<std::chrono::milliseconds> timeout = std::nullopt) const override
{
PYBIND11_OVERLOAD_PURE(
std::deque<tensorrt_llm::executor::KVCacheEvent>, tbk::BaseKVCacheManager, getLatestEvents, timeout);
}
tensorrt_llm::runtime::ITensor::SharedPtr getPrimaryPool(SizeType32 poolIdx) const override
{
PYBIND11_OVERLOAD_PURE(
tensorrt_llm::runtime::ITensor::SharedPtr, tbk::BaseKVCacheManager, getPrimaryPool, poolIdx);
}
tensorrt_llm::runtime::ITensor::SharedPtr getIndexerKCachePool() const override
{
PYBIND11_OVERLOAD_PURE(
tensorrt_llm::runtime::ITensor::SharedPtr, tbk::BaseKVCacheManager, getIndexerKCachePool);
}
tensorrt_llm::runtime::ITensor::SharedPtr getUniquePrimaryPool() const override
{
PYBIND11_OVERLOAD_PURE(
tensorrt_llm::runtime::ITensor::SharedPtr, tbk::BaseKVCacheManager, getUniquePrimaryPool);
}
SizeType32 getPoolLayerIdx(SizeType32 layer_idx) const override
{
PYBIND11_OVERLOAD_PURE(SizeType32, tbk::BaseKVCacheManager, getPoolLayerIdx, layer_idx);
}
void syncTransferManagerWithBufferManager() override
{
PYBIND11_OVERLOAD_PURE(void, tbk::BaseKVCacheManager, syncTransferManagerWithBufferManager);
}
void refreshBlocks() override
{
PYBIND11_OVERLOAD_PURE(void, tbk::BaseKVCacheManager, refreshBlocks);
}
void flushIterationEvents() override
{
PYBIND11_OVERLOAD_PURE(void, tbk::BaseKVCacheManager, flushIterationEvents);
}
};
// TODO: Deduplicate with the KvCacheStats bindings defined in the executor bindings
class PyBasePeftCacheManager : public tb::BasePeftCacheManager
{
public:
~PyBasePeftCacheManager() override = default;
void addRequestPeft(tb::BasePeftCacheManager::LlmRequestPtr llmRequest, bool tryGpuCache = true) override
{
PYBIND11_OVERLOAD_PURE(void, tb::BasePeftCacheManager, addRequestPeft, llmRequest, tryGpuCache);
}
tb::BasePeftCacheManager::PeftTable ensureBatch(tb::RequestVector const& contextRequests,
tb::RequestVector const& generationRequests, bool resetGpuCache = false) override
{
PYBIND11_OVERLOAD_PURE(tb::BasePeftCacheManager::PeftTable, tb::BasePeftCacheManager, ensureBatch,
contextRequests, generationRequests, resetGpuCache);
}
void resetDeviceCache() override
{
PYBIND11_OVERLOAD_PURE(void, tb::BasePeftCacheManager, resetDeviceCache);
}
void markRequestDone(tb::LlmRequest const& llmReq, bool pause = false) override
{
PYBIND11_OVERLOAD_PURE(void, tb::BasePeftCacheManager, markRequestDone, llmReq, pause);
}
tr::SizeType32 getMaxDevicePages() const override
{
PYBIND11_OVERLOAD_PURE(tr::SizeType32, tb::BasePeftCacheManager, getMaxDevicePages);
}
tr::SizeType32 getMaxHostPages() const override
{
PYBIND11_OVERLOAD_PURE(tr::SizeType32, tb::BasePeftCacheManager, getMaxHostPages);
}
tr::SizeType32 determineNumPages(std::shared_ptr<tb::LlmRequest> llmRequest) const override
{
PYBIND11_OVERLOAD_PURE(tr::SizeType32, tb::BasePeftCacheManager, determineNumPages, llmRequest);
}
bool enabled() const override
{
PYBIND11_OVERLOAD_PURE(bool, tb::BasePeftCacheManager, enabled);
}
};
} // namespace
void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(py::module_& m)
{
py::class_<tbk::KvCacheStats>(m, "KvCacheStats")
.def(py::init<>())
.def_readwrite("max_num_blocks", &tbk::KvCacheStats::maxNumBlocks)
.def_readwrite("free_num_blocks", &tbk::KvCacheStats::freeNumBlocks)
.def_readwrite("used_num_blocks", &tbk::KvCacheStats::usedNumBlocks)
.def_readwrite("tokens_per_block", &tbk::KvCacheStats::toksPerBlock)
.def_readwrite("alloc_total_blocks", &tbk::KvCacheStats::allocTotalBlocks)
.def_readwrite("alloc_new_blocks", &tbk::KvCacheStats::allocNewBlocks)
.def_readwrite("reused_blocks", &tbk::KvCacheStats::reusedBlocks)
.def_readwrite("missed_blocks", &tbk::KvCacheStats::missedBlocks)
.def_readwrite("cache_hit_rate", &tbk::KvCacheStats::cacheHitRate)
.def_readwrite("num_free_blocks_per_window_size", &tbk::KvCacheStats::numFreeBlocksPerWindowSize)
.def_readonly("allocated_bytes", &tbk::KvCacheStats::allocatedBytes);
py::class_<tbk::TempAttentionWindowInputs>(m, "TempAttentionWindowInputs")
.def(py::init<>())
.def_readwrite("paged_context_fmha", &tbk::TempAttentionWindowInputs::pagedContextFMHA)
.def_readwrite("max_input_len", &tbk::TempAttentionWindowInputs::maxInputLen)
.def_readwrite("max_num_tokens", &tbk::TempAttentionWindowInputs::maxNumTokens);
py::class_<tbk::BlockKey>(m, "BlockKey")
.def(py::init<>())
.def(py::init<VecTokens const&, std::optional<tr::LoraTaskIdType>>(), py::arg("tokens"),
py::arg("lora_task_id") = std::nullopt)
.def(py::init<bool, std::optional<tr::LoraTaskIdType>, VecUniqueTokens const&>(), py::arg("uses_extra_ids"),
py::arg("lora_task_id"), py::arg("unique_tokens"))
.def_readonly("uses_extra_ids", &tbk::BlockKey::usesExtraIds)
.def_readonly("lora_task_id", &tbk::BlockKey::loraTaskId)
.def_readonly("unique_tokens", &tbk::BlockKey::uniqueTokens);
py::class_<tbk::BlockKeyHasher>(m, "BlockKeyHasher")
.def_static("hash", &tbk::BlockKeyHasher::hash, py::arg("block_key"), py::arg("parent_hash") = 0);
py::class_<tbk::KVCacheEventManager, std::shared_ptr<tbk::KVCacheEventManager>>(m, "KVCacheEventManager")
.def(py::init<size_t, std::optional<SizeType32>, std::optional<SizeType32>, SizeType32>(),
py::arg("max_kv_event_entries"), py::arg("attention_dp_rank") = std::nullopt,
py::arg("attention_dp_size") = std::nullopt, py::arg("attention_dp_events_gather_period_ms") = 5);
py::classh<tbk::BaseKVCacheManager, PyKvCacheManager>(m, "BaseKVCacheManager")
.def_static("calculate_max_num_blocks", &tbk::BaseKVCacheManager::calculateMaxNumBlocks, py::arg("config"),
py::arg("is_cross_attention"), py::arg("dtype"), py::arg("model_config"), py::arg("world_config"),
py::arg("window_size_to_layers"), py::arg("allotted_primary_mem_bytes"),
py::arg("allotted_secondary_mem_bytes"), py::arg("extra_cost_memory"), py::arg("kv_factor"),
py::call_guard<py::gil_scoped_release>())
.def("allocate_pools", &BaseKVCacheManager::allocatePools, py::call_guard<py::gil_scoped_release>())
.def("release_pools", &BaseKVCacheManager::releasePools, py::call_guard<py::gil_scoped_release>())
.def("start_scheduling", &BaseKVCacheManager::startScheduling, py::call_guard<py::gil_scoped_release>())
.def_property_readonly("tokens_per_block", &BaseKVCacheManager::getTokensPerBlock)
.def_property_readonly("max_num_blocks", &BaseKVCacheManager::getMaxNumBlocks)
.def_property_readonly("num_pools", &BaseKVCacheManager::getNumPools)
.def("get_kv_cache_stats", &BaseKVCacheManager::getKvCacheStats, py::call_guard<py::gil_scoped_release>())
.def_property_readonly("max_blocks_per_seq",
[](tbk::BaseKVCacheManager& self) { return self.getOffsetTableDimensions().maxBlocksPerSeq; })
.def("get_needed_blocks_one_step", &BaseKVCacheManager::getNeededBlocksOneStep,
py::call_guard<py::gil_scoped_release>())
.def("get_remaining_blocks_to_completion", &BaseKVCacheManager::getRemainingBlocksToCompletion,
py::call_guard<py::gil_scoped_release>())
.def("add_token", &BaseKVCacheManager::addToken, py::call_guard<py::gil_scoped_release>())
.def("add_sequence", &BaseKVCacheManager::addSequence, py::call_guard<py::gil_scoped_release>())
.def("remove_sequence", &BaseKVCacheManager::removeSequence, py::call_guard<py::gil_scoped_release>())
.def("pin_blocks", &BaseKVCacheManager::pinBlocks, py::call_guard<py::gil_scoped_release>())
.def("scheduling_remove_sequence", &BaseKVCacheManager::schedulingRemoveSequence,
py::call_guard<py::gil_scoped_release>())
.def(
"get_block_pool_pointers",
[](tbk::BaseKVCacheManager& self)
{
std::optional<at::Tensor> block_pool_pointers{std::nullopt};
auto tensor = self.getBlockPoolPointers();
if (tensor)
{
std::shared_ptr<tensorrt_llm::runtime::ITensor> _tensor = std::move(tensor);
block_pool_pointers = tr::Torch::tensor(_tensor);
}
return block_pool_pointers;
},
py::call_guard<py::gil_scoped_release>())
.def(
"get_block_scale_pool_pointers",
[](tbk::BaseKVCacheManager& self)
{
std::optional<at::Tensor> block_scale_pool_pointers{std::nullopt};
auto tensor = self.getBlockScalePoolPointers();
if (tensor)
{
std::shared_ptr<tensorrt_llm::runtime::ITensor> _tensor = std::move(tensor);
block_scale_pool_pointers = tr::Torch::tensor(_tensor);
}
return block_scale_pool_pointers;
},
py::call_guard<py::gil_scoped_release>())
.def(
"get_layer_to_pool_mapping",
[](tbk::BaseKVCacheManager& self)
{
std::optional<at::Tensor> layer_to_pool_mapping{std::nullopt};
auto tensor = self.getLayerToPoolMapping();
if (tensor)
{
std::shared_ptr<tensorrt_llm::runtime::ITensor> _tensor = std::move(tensor);
layer_to_pool_mapping = tr::Torch::tensor(_tensor);
}
return layer_to_pool_mapping;
},
py::call_guard<py::gil_scoped_release>())
.def(
"get_primary_pool_data",
[](tbk::BaseKVCacheManager& self, SizeType32 layer_idx) -> at::Tensor
{
auto pool = tr::Torch::tensor(self.getPrimaryPool(layer_idx));
auto pool_layer_idx = self.getPoolLayerIdx(layer_idx);
return pool.index({torch::indexing::Slice(), pool_layer_idx});
},
py::call_guard<py::gil_scoped_release>())
.def(
"get_indexer_k_cache_pool_data",
[](tbk::BaseKVCacheManager& self, SizeType32 layer_idx) -> at::Tensor
{
auto pool = tr::Torch::tensor(self.getIndexerKCachePool());
return pool.index({torch::indexing::Slice(), layer_idx});
},
py::call_guard<py::gil_scoped_release>())
.def(
"get_unique_primary_pool", [](tbk::BaseKVCacheManager& self) { return self.getUniquePrimaryPool(); },
py::call_guard<py::gil_scoped_release>())
.def(
"get_block_offsets_of_batch",
[](tbk::BaseKVCacheManager& self, at::Tensor output, SizeType32 firstBatchSlotIdx, SizeType32 batchSize,
SizeType32 beamWidth)
{
auto _output = from_torch(output);
TLLM_CHECK_WITH_INFO(_output.has_value(), "Invalid output tensor.");
self.getBlockOffsetsOfBatch(*(_output.value()), firstBatchSlotIdx, batchSize, beamWidth);
},
py::call_guard<py::gil_scoped_release>())
.def(
"copy_block_offsets",
[](tbk::BaseKVCacheManager& self, at::Tensor output, SizeType32 outputSlotOffset,
tb::LlmRequest::RequestIdType requestId)
{
auto _output = from_torch(output);
TLLM_CHECK_WITH_INFO(_output.has_value(), "Invalid output tensor.");
auto maxBlockCount = self.copyBlockOffsets(*(_output.value()), outputSlotOffset, requestId);
return maxBlockCount;
},
py::call_guard<py::gil_scoped_release>())
.def(
"copy_batch_block_offsets",
[](tbk::BaseKVCacheManager& self, at::Tensor output,
std::vector<tb::LlmRequest::RequestIdType> const& requestIds, SizeType32 const beamWidth,
SizeType32 const offset)
{
auto _output = from_torch(output);
TLLM_CHECK_WITH_INFO(_output.has_value(), "Invalid output tensor.");
for (size_t i = 0; i < requestIds.size(); ++i)
{
self.copyBlockOffsets(*(_output.value()), i * beamWidth + offset, requestIds[i]);
}
},
py::call_guard<py::gil_scoped_release>())
.def(
"get_latest_events",
[](tbk::BaseKVCacheManager& self, std::optional<double> timeout_ms = std::nullopt)
{
if (timeout_ms)
{
return self.getLatestEvents(std::chrono::milliseconds(static_cast<int64_t>(*timeout_ms)));
}
return self.getLatestEvents(std::nullopt);
},
py::arg("timeout_ms") = std::nullopt, py::call_guard<py::gil_scoped_release>())
.def_property_readonly("enable_block_reuse", &BaseKVCacheManager::isEnableBlockReuse)
.def("rewind_kv_cache", &BaseKVCacheManager::rewindKVCache, py::call_guard<py::gil_scoped_release>())
.def_property_readonly("cross_kv", &BaseKVCacheManager::isCrossKv)
.def("store_context_blocks", &BaseKVCacheManager::storeContextBlocks, py::call_guard<py::gil_scoped_release>())
.def("store_blocks_for_reuse", &BaseKVCacheManager::storeBlocksForReuse,
py::call_guard<py::gil_scoped_release>())
.def("find_new_context_block", &BaseKVCacheManager::findNewContextBlock, py::arg("unique_tokens"),
py::arg("llm_request"), py::call_guard<py::gil_scoped_release>())
.def("get_cache_block_ids", &BaseKVCacheManager::getCacheBlockIds, py::call_guard<py::gil_scoped_release>())
.def("get_batch_cache_block_ids", &BaseKVCacheManager::getBatchCacheBlockIds,
py::call_guard<py::gil_scoped_release>())
.def("flush_iteration_events", &BaseKVCacheManager::flushIterationEvents,
py::call_guard<py::gil_scoped_release>())
.def("sync_transfer_manager_with_buffer_manager", &BaseKVCacheManager::syncTransferManagerWithBufferManager,
py::call_guard<py::gil_scoped_release>())
.def("refresh_blocks", &BaseKVCacheManager::refreshBlocks, py::call_guard<py::gil_scoped_release>())
.def("get_last_block_id", &BaseKVCacheManager::getLastBlockId, py::call_guard<py::gil_scoped_release>())
.def("unpin_blocks_by_id", &BaseKVCacheManager::unpinBlocksById, py::call_guard<py::gil_scoped_release>())
.def("reset_reuse_state", &BaseKVCacheManager::resetReuseState, py::call_guard<py::gil_scoped_release>());
py::enum_<tbk::CacheType>(m, "CacheType")
.value("SELF", tbk::CacheType::kSELF)
.value("CROSS", tbk::CacheType::kCROSS)
.value("SELFKONLY", tbk::CacheType::kSELFKONLY);
py::classh<tbk::KVCacheManager, tbk::BaseKVCacheManager>(m, "KVCacheManager")
.def(py::init<std::vector<SizeType32> const&, SizeType32, SizeType32,
std::map<SizeType32, std::tuple<SizeType32, SizeType32>> const&, SizeType32, SizeType32,
std::vector<SizeType32> const&, std::optional<tbk::TempAttentionWindowInputs> const&,
nvinfer1::DataType, SizeType32, bool, int64_t, bool, bool, tbk::CacheType,
std::optional<tensorrt_llm::executor::RetentionPriority>, std::shared_ptr<tbk::KVCacheEventManager>,
bool, bool, std::shared_ptr<tbc::KvCacheConnectorManager>, bool, SizeType32, SizeType32>(),
py::arg("num_kv_heads_per_layer"), py::arg("size_per_head"), py::arg("tokens_per_block"),
py::arg("blocks_per_window"), py::arg("max_num_sequences"), py::arg("max_beam_width"),
py::arg("max_attention_window_vec"), py::arg("temp_attention_window_inputs"), py::arg("dtype"),
py::arg("sink_token_length"), py::arg("stream"), py::arg("max_sequence_length"),
py::arg("enable_block_reuse") = false, py::arg("onboard_blocks") = true,
py::arg_v("cache_type", tbk::CacheType::kSELF, "bindings.internal.batch_manager.CacheType.SELF"),
py::arg("secondary_offload_min_priority") = std::nullopt, py::arg("event_manager") = nullptr,
py::arg("enable_partial_reuse") = true, py::arg("copy_on_partial_reuse") = true,
py::arg("kv_connector_manager") = nullptr, py::arg("enable_indexer_k_cache") = false,
py::arg("indexer_k_cache_quant_block_size") = 128, py::arg("indexer_k_cache_index_head_dim") = 0,
py::call_guard<py::gil_scoped_release>())
.def(
"scheduling_has_free_blocks",
[](tbk::KVCacheManager& self, SizeType32 numRequired, SizeType32 windowSize)
{ return self.getBlockManager().schedulingHasFreeBlocks(numRequired, windowSize); },
py::arg("num_required"), py::arg("window_size"), py::call_guard<py::gil_scoped_release>())
.def_property_readonly(
"is_variable_window", [](tbk::KVCacheManager& self) { return self.getBlockManager().isVariableWindow(); });
}
void tb::BasePeftCacheManagerBindings::initBindings(py::module_& m)
{
py::classh<tb::BasePeftCacheManager, PyBasePeftCacheManager>(m, "BasePeftCacheManager")
.def("add_request_peft", &tb::BasePeftCacheManager::addRequestPeft, py::arg("request"),
py::arg("try_gpu_cache") = true, py::call_guard<py::gil_scoped_release>())
.def(
"ensure_batch",
[](tb::BasePeftCacheManager& self, tb::RequestVector const& contextRequests,
tb::RequestVector const& generationRequests, bool resetGpuCache)
{ return self.ensureBatch(contextRequests, generationRequests, resetGpuCache); },
py::arg("context_requests"), py::arg("generation_requests"), py::arg("reset_gpu_cache") = false,
py::call_guard<py::gil_scoped_release>())
.def(
"reset_device_cache", &tb::BasePeftCacheManager::resetDeviceCache, py::call_guard<py::gil_scoped_release>())
.def("mark_request_done", &tb::BasePeftCacheManager::markRequestDone, py::arg("request"),
py::arg("pause") = false, py::call_guard<py::gil_scoped_release>())
.def_property_readonly("max_device_pages", &tb::BasePeftCacheManager::getMaxDevicePages)
.def_property_readonly("max_host_pages", &tb::BasePeftCacheManager::getMaxHostPages)
.def("determine_num_pages", &tb::BasePeftCacheManager::determineNumPages, py::arg("request"),
py::call_guard<py::gil_scoped_release>())
.def_property_readonly("enabled", &tb::BasePeftCacheManager::enabled);
py::classh<tb::PeftCacheManager, tb::BasePeftCacheManager>(m, "PeftCacheManager")
.def(py::init<tb::PeftCacheManagerConfig, tr::ModelConfig, tr::WorldConfig, tr::BufferManager>(),
py::arg("config"), py::arg("model_config"), py::arg("world_config"), py::arg("buffer_manager"),
py::call_guard<py::gil_scoped_release>())
.def("is_task_cached", &tb::PeftCacheManager::isTaskCached, py::arg("taskId"),
py::call_guard<py::gil_scoped_release>())
.def("is_task_cached_device", &tb::PeftCacheManager::isTaskCachedDevice, py::arg("taskId"),
py::call_guard<py::gil_scoped_release>())
.def("ensure_batch_map_task_id", &tb::PeftCacheManager::ensureBatchMapTaskId, py::arg("context_requests"),
py::arg("generation_requests"), py::arg("reset_gpu_cache") = false,
py::call_guard<py::gil_scoped_release>());
py::classh<tb::NoOpPeftCacheManager, tb::BasePeftCacheManager>(m, "NoOpPeftCacheManager")
.def(py::init<>(), py::call_guard<py::gil_scoped_release>());
}
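A minimal, self-contained sketch of the trampoline pattern that PyKvCacheManager and PyBasePeftCacheManager use above (the interface and names here are hypothetical): a helper class overrides every pure-virtual method and forwards it to Python with PYBIND11_OVERLOAD_PURE, which is what lets the abstract C++ interface be subclassed from Python.

#include <pybind11/pybind11.h>
namespace py = pybind11;

class CacheManagerIface
{
public:
    virtual ~CacheManagerIface() = default;
    virtual int getTokensPerBlock() const = 0;
};

// Trampoline: dispatches the pure-virtual call to a Python override.
class PyCacheManagerIface : public CacheManagerIface
{
public:
    int getTokensPerBlock() const override
    {
        PYBIND11_OVERLOAD_PURE(int, CacheManagerIface, getTokensPerBlock);
    }
};

PYBIND11_MODULE(cache_demo, m)
{
    // Registering the trampoline as the second template argument allows
    // "class MyManager(CacheManagerIface)" on the Python side.
    py::class_<CacheManagerIface, PyCacheManagerIface>(m, "CacheManagerIface")
        .def(py::init<>())
        .def("get_tokens_per_block", &CacheManagerIface::getTokensPerBlock);
}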

View File

@ -1,39 +0,0 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/pybind/common/customCasters.h"
#include <pybind11/pybind11.h>
namespace tensorrt_llm::batch_manager::kv_cache_manager
{
class KVCacheManagerBindings
{
public:
static void initBindings(pybind11::module_& m);
};
} // namespace tensorrt_llm::batch_manager::kv_cache_manager
namespace tensorrt_llm::batch_manager
{
class BasePeftCacheManagerBindings
{
public:
static void initBindings(pybind11::module_& m);
};
} // namespace tensorrt_llm::batch_manager

View File

@ -1,111 +0,0 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kvCacheManagerV2Utils.h"
#include "tensorrt_llm/batch_manager/kvCacheManagerV2Utils.h"
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/stl_bind.h>
namespace py = pybind11;
namespace tensorrt_llm::batch_manager::kv_cache_manager_v2
{
void KVCacheManagerV2UtilsBindings::initBindings(py::module_& module)
{
// Bind DiskAddress struct
py::class_<DiskAddress>(module, "DiskAddress")
.def(py::init<int, ssize_t>(), py::arg("fd"), py::arg("pos"))
.def_readwrite("fd", &DiskAddress::fd)
.def_readwrite("pos", &DiskAddress::pos);
// Bind Task template instantiations
py::class_<Task<DiskAddress, DiskAddress>>(module, "DiskToDiskTask")
.def(py::init<DiskAddress, DiskAddress>(), py::arg("dst"), py::arg("src"))
.def_readwrite("dst", &Task<DiskAddress, DiskAddress>::dst)
.def_readwrite("src", &Task<DiskAddress, DiskAddress>::src);
py::class_<Task<MemAddress, DiskAddress>>(module, "DiskToHostTask")
.def(py::init<MemAddress, DiskAddress>(), py::arg("dst"), py::arg("src"))
.def_readwrite("dst", &Task<MemAddress, DiskAddress>::dst)
.def_readwrite("src", &Task<MemAddress, DiskAddress>::src);
py::class_<Task<DiskAddress, MemAddress>>(module, "HostToDiskTask")
.def(py::init<DiskAddress, MemAddress>(), py::arg("dst"), py::arg("src"))
.def_readwrite("dst", &Task<DiskAddress, MemAddress>::dst)
.def_readwrite("src", &Task<DiskAddress, MemAddress>::src);
py::class_<Task<MemAddress, MemAddress>>(module, "MemToMemTask")
.def(py::init<MemAddress, MemAddress>(), py::arg("dst"), py::arg("src"))
.def_readwrite("dst", &Task<MemAddress, MemAddress>::dst)
.def_readwrite("src", &Task<MemAddress, MemAddress>::src);
// Bind copy functions
module.def(
"copy_disk_to_disk",
[](std::vector<Task<DiskAddress, DiskAddress>> tasks, ssize_t numBytes, uintptr_t stream) -> int
{ return copyDiskToDisk(std::move(tasks), numBytes, reinterpret_cast<CUstream>(stream)); },
py::arg("tasks"), py::arg("num_bytes"), py::arg("stream"), py::call_guard<py::gil_scoped_release>(),
"Copy data from disk to disk using CUDA host function");
module.def(
"copy_disk_to_host",
[](std::vector<Task<MemAddress, DiskAddress>> tasks, ssize_t numBytes, uintptr_t stream) -> int
{ return copyDiskToHost(std::move(tasks), numBytes, reinterpret_cast<CUstream>(stream)); },
py::arg("tasks"), py::arg("num_bytes"), py::arg("stream"), py::call_guard<py::gil_scoped_release>(),
"Copy data from disk to host using CUDA host function");
module.def(
"copy_host_to_disk",
[](std::vector<Task<DiskAddress, MemAddress>> tasks, ssize_t numBytes, uintptr_t stream) -> int
{ return copyHostToDisk(std::move(tasks), numBytes, reinterpret_cast<CUstream>(stream)); },
py::arg("tasks"), py::arg("num_bytes"), py::arg("stream"), py::call_guard<py::gil_scoped_release>(),
"Copy data from host to disk using CUDA host function");
module.def(
"copy_host_to_host",
[](std::vector<Task<MemAddress, MemAddress>> tasks, ssize_t numBytes, uintptr_t stream) -> int
{ return copyHostToHost(std::move(tasks), numBytes, reinterpret_cast<CUstream>(stream)); },
py::arg("tasks"), py::arg("num_bytes"), py::arg("stream"), py::call_guard<py::gil_scoped_release>(),
"Copy data from host to host using CUDA host function");
module.def(
"copy_host_to_device",
[](std::vector<Task<MemAddress, MemAddress>> const& tasks, ssize_t numBytes, uintptr_t stream) -> int
{ return copyHostToDevice(tasks, numBytes, reinterpret_cast<CUstream>(stream)); },
py::arg("tasks"), py::arg("num_bytes"), py::arg("stream"), py::call_guard<py::gil_scoped_release>(),
"Copy data from host to device using CUDA kernels");
module.def(
"copy_device_to_host",
[](std::vector<Task<MemAddress, MemAddress>> const& tasks, ssize_t numBytes, uintptr_t stream) -> int
{ return copyDeviceToHost(tasks, numBytes, reinterpret_cast<CUstream>(stream)); },
py::arg("tasks"), py::arg("num_bytes"), py::arg("stream"), py::call_guard<py::gil_scoped_release>(),
"Copy data from device to host using CUDA kernels");
module.def(
"copy_device_to_device",
[](std::vector<Task<MemAddress, MemAddress>> const& tasks, ssize_t numBytes, uintptr_t stream) -> int
{ return copyDeviceToDevice(tasks, numBytes, reinterpret_cast<CUstream>(stream)); },
py::arg("tasks"), py::arg("num_bytes"), py::arg("stream"), py::call_guard<py::gil_scoped_release>(),
"Copy data from device to device using CUDA kernels");
}
} // namespace tensorrt_llm::batch_manager::kv_cache_manager_v2
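A minimal sketch of the stream-passing convention the copy helpers above rely on (hypothetical function and module names; the real code uses the driver-API CUstream): Python hands over the raw CUDA stream address as an integer, the lambda reinterpret_casts it back to a stream handle, and py::call_guard<py::gil_scoped_release> drops the GIL while the copy is enqueued.

#include <cstddef>
#include <cstdint>
#include <cuda_runtime_api.h>
#include <pybind11/pybind11.h>
namespace py = pybind11;

// Hypothetical worker that enqueues work on the given stream; returns 0 on success.
static int enqueueCopy(std::size_t numBytes, cudaStream_t stream)
{
    (void) numBytes;
    (void) stream;
    return 0;
}

PYBIND11_MODULE(copy_demo, m)
{
    m.def(
        "enqueue_copy",
        [](std::size_t numBytes, std::uintptr_t stream) -> int
        { return enqueueCopy(numBytes, reinterpret_cast<cudaStream_t>(stream)); },
        py::arg("num_bytes"), py::arg("stream"), py::call_guard<py::gil_scoped_release>(),
        "Enqueue a copy on the CUDA stream whose address is passed as an integer");
}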

View File

@ -1,29 +0,0 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <pybind11/pybind11.h>
namespace tensorrt_llm::batch_manager::kv_cache_manager_v2
{
class KVCacheManagerV2UtilsBindings
{
public:
static void initBindings(pybind11::module_& module);
};
} // namespace tensorrt_llm::batch_manager::kv_cache_manager_v2

View File

@ -1,131 +0,0 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "llmRequest.h"
#include "tensorrt_llm/batch_manager/llmRequest.h"
#include "tensorrt_llm/pybind/common/bindTypes.h"
#include "tensorrt_llm/runtime/torch.h"
#include "tensorrt_llm/runtime/torchUtils.h"
#include "tensorrt_llm/runtime/torchView.h"
#include <ATen/ATen.h>
#include <torch/extension.h>
#include <memory>
namespace tb = tensorrt_llm::batch_manager;
namespace tr = tensorrt_llm::runtime;
namespace tle = tensorrt_llm::executor;
using namespace tensorrt_llm::pybind::batch_manager;
using LlmRequestPtr = std::shared_ptr<tb::LlmRequest>;
using RequestList = std::list<LlmRequestPtr>;
namespace
{
std::optional<tb::LlmRequest::TensorPtr> from_torch(std::optional<LlmRequest::TensorPtr> torchPtr)
{
if (torchPtr)
{
return tr::TorchView::of(torchPtr.value());
}
return std::nullopt;
}
} // namespace
std::optional<tb::LlmRequest::LogitsPostProcessor> LlmRequest::callbackAdapter(
std::optional<LlmRequest::LogitsPostProcessor> callback)
{
if (!callback)
{
return std::nullopt;
}
return [callback](RequestIdType reqId, tr::ITensor::SharedPtr& tensor, tb::LlmRequest::BeamTokens const& tokens,
tr::BufferManager::CudaStreamPtr stream, std::optional<RequestIdType> clientId)
{
at::Tensor atTensor = tr::Torch::tensor(tensor);
callback.value()(reqId, atTensor, tokens, runtime::TorchUtils::stream(*stream).unwrap(), clientId);
};
}
std::shared_ptr<tb::LlmRequest> LlmRequest::toTrtLlm() const
{
auto const draftTokens = std::make_shared<std::vector<TokenIdType>>(*mDraftTokens.get());
auto const optDraftTokens = std::optional<std::shared_ptr<std::vector<TokenIdType>>>(draftTokens);
auto const encoderInputTokens = mEncoderTokens.has_value()
? std::make_shared<std::vector<TokenIdType>>(*mEncoderTokens.value().get())
: nullptr;
auto const optEncoderInputTokens = std::optional<std::shared_ptr<std::vector<TokenIdType>>>(encoderInputTokens);
return std::make_shared<tb::LlmRequest>( //
mRequestId, //
mMaxNewTokens, //
std::make_shared<std::vector<TokenIdType>>(mTokens.at(0)), //
mSamplingConfig, //
mIsStreaming, //
mEndId, //
mPadId, //
from_torch(mEmbeddingBias), //
from_torch(mBadWordsList), //
from_torch(mStopWordsList), //
mPositionIds, //
from_torch(mPromptEmbeddingTable), //
mPromptVocabSize, //
mMultimodalHashes, //
mMultimodalPositions, //
mMultimodalLengths, //
from_torch(mMultimodalEmbedding), //
from_torch(mMropeRotaryCosSin), //
mMropePositionDeltas, //
mLoraTaskId, //
from_torch(mLoraWeights), //
from_torch(mLoraConfig), //
mLookaheadConfig, //
mKvCacheRetentionConfig, //
mReturnLogProbs, //
mReturnContextLogits, //
mReturnGenerationLogits, //
optDraftTokens, //
from_torch(mDraftLogits), //
mExcludeInputFromOutput, //
callbackAdapter(mLogitsPostProcessor), //
mApplyLogitsPostProcessorBatched, //
optEncoderInputTokens, //
mReturnEncoderOutput, //
mClientId, //
mPriority, //
from_torch(mEncoderInputFeatures), //
mEncoderOutputLength, //
from_torch(mCrossAttentionMask), //
getLlmRequestType(), //
std::nullopt, // inputTokenExtraIds
mNumReturnSequences, //
mEagleConfig, //
from_torch(mSkipCrossAttnBlocks), //
false, // returnPerfMetrics
mGuidedDecodingParams, //
mLanguageAdapterUid, //
mAllottedTimeMs, //
mContextPhaseParams, //
mCacheSaltID, //
mPerfMetrics.timingMetrics.arrivalTime //
);
}
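A minimal, standalone sketch of the callbackAdapter idea above, using stand-in types rather than at::Tensor and ITensor: an optional callback written against the Python-facing tensor type is wrapped in a lambda that converts the engine's tensor type at call time, so the C++ runtime can invoke a user-supplied post-processor without depending on torch.

#include <functional>
#include <optional>
#include <vector>

struct EngineTensor { std::vector<float> data; };   // stand-in for the runtime tensor
struct PyTensor     { std::vector<float> data; };   // stand-in for the Python-facing tensor

using PyCallback     = std::function<void(PyTensor&)>;
using EngineCallback = std::function<void(EngineTensor&)>;

std::optional<EngineCallback> adaptCallback(std::optional<PyCallback> callback)
{
    if (!callback)
    {
        return std::nullopt;
    }
    return [callback](EngineTensor& tensor)
    {
        PyTensor view{tensor.data};   // convert to the Python-facing type
        callback.value()(view);       // run the user-supplied callback
        tensor.data = view.data;      // write any modifications back to the engine tensor
    };
}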

View File

@ -1,162 +0,0 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/batch_manager/llmRequest.h"
#include "tensorrt_llm/pybind/common/customCasters.h"
#include <ATen/ATen.h>
#include <ATen/ops/tensor.h>
#include <memory>
#include <optional>
#include <pybind11/pybind11.h>
namespace tensorrt_llm::pybind::batch_manager
{
namespace tb = tensorrt_llm::batch_manager;
/* Unfortunately, torch's default pybind bindings don't know about c10::cuda::CUDAStream,
* so we have to pass the more generic c10::Stream, and convert it back to a full-fledged
* torch.cuda.Stream in python. See example in test/bindings/test_gpt_manager.py
*/
class LlmRequest : public tb::GenericLlmRequest<at::Tensor, c10::Stream>
{
public:
using Base = GenericLlmRequest<at::Tensor, c10::Stream>;
using TensorPtr = Base::TensorPtr;
using SizeType32 = Base::SizeType32;
using TokenIdType = Base::TokenIdType;
using RequestIdType = Base::RequestIdType;
using LoraTaskIdType = Base::LoraTaskIdType;
using VecLogProbs = Base::VecLogProbs;
using BeamTokens = Base::BeamTokens;
using VecTokens = Base::VecTokens;
using VecTokenExtraIds = Base::VecTokenExtraIds;
using LogitsPostProcessor = Base::LogitsPostProcessor;
using CacheSaltIDType = Base::CacheSaltIDType;
LlmRequest(RequestIdType requestId, SizeType32 maxNewTokens, std::vector<TokenIdType> inputTokens,
runtime::SamplingConfig samplingConfig, bool isStreaming, std::optional<SizeType32> endId = std::nullopt,
std::optional<SizeType32> padId = std::nullopt, std::optional<TensorPtr> embeddingBias = std::nullopt,
std::optional<TensorPtr> badWordsList = std::nullopt, std::optional<TensorPtr> stopWordsList = std::nullopt,
std::optional<std::vector<SizeType32>> positionIds = std::nullopt,
std::optional<TensorPtr> promptEmbeddingTable = std::nullopt,
std::optional<SizeType32> promptVocabSize = std::nullopt,
std::optional<std::vector<std::vector<SizeType32>>> multimodalHashes = std::nullopt,
std::optional<std::vector<SizeType32>> multimodalPositions = std::nullopt,
std::optional<std::vector<SizeType32>> multimodalLengths = std::nullopt,
std::optional<TensorPtr> multimodalEmbedding = std::nullopt,
std::optional<TensorPtr> mropeRotaryCosSin = std::nullopt,
std::optional<SizeType32> mropePositionDeltas = std::nullopt,
std::optional<LoraTaskIdType> loraTaskId = std::nullopt, std::optional<TensorPtr> loraWeights = std::nullopt,
std::optional<TensorPtr> loraConfig = std::nullopt,
std::optional<executor::LookaheadDecodingConfig> lookaheadConfig = std::nullopt,
std::optional<executor::KvCacheRetentionConfig> kvCacheRetentionConfig = std::nullopt,
bool returnLogProbs = false, bool returnContextLogits = false, bool returnGenerationLogits = false,
std::optional<VecTokens> draftTokens = std::nullopt, std::optional<TensorPtr> draftLogits = std::nullopt,
bool excludeInputFromOutput = false, std::optional<LogitsPostProcessor> logitsPostProcessor = std::nullopt,
bool applyLogitsPostProcessorBatched = false, std::optional<VecTokens> encoderInputTokens = std::nullopt,
bool returnEncoderOutput = false, std::optional<RequestIdType> clientId = std::nullopt,
executor::PriorityType priority = executor::Request::kDefaultPriority,
std::optional<TensorPtr> encoderInputFeatures = std::nullopt,
std::optional<SizeType32> encoderOutputLength = std::nullopt,
std::optional<TensorPtr> crossAttentionMask = std::nullopt,
tb::LlmRequestType llmRequestType = tb::LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION,
std::optional<VecTokenExtraIds> inputTokenExtraIds = std::nullopt, SizeType32 numReturnSequences = 1,
std::optional<executor::EagleConfig> eagleConfig = std::nullopt,
std::optional<TensorPtr> skipCrossAttnBlocks = std::nullopt, bool returnPerfMetrics = false,
std::optional<executor::GuidedDecodingParams> guidedDecodingParams = std::nullopt,
std::optional<SizeType32> languageAdapterUid = std::nullopt,
std::optional<MillisecondsType> allottedTimeMs = std::nullopt,
std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt,
std::optional<CacheSaltIDType> cacheSaltID = std::nullopt, std::optional<TimePoint> arrivalTime = std::nullopt)
: Base(requestId, //
maxNewTokens, //
std::make_shared<std::vector<TokenIdType>>(std::move(inputTokens)), //
samplingConfig, //
isStreaming, //
endId, //
padId, //
embeddingBias, //
badWordsList, //
stopWordsList, //
positionIds.has_value() ? std::make_shared<std::vector<SizeType32>>(std::move(positionIds.value())) //
: std::optional<std::shared_ptr<std::vector<SizeType32>>>(std::nullopt), //
promptEmbeddingTable, //
promptVocabSize, //
multimodalHashes.has_value()
? std::make_optional(
std::make_shared<std::vector<std::vector<SizeType32>>>(std::move(multimodalHashes.value()))) //
: std::optional<std::shared_ptr<std::vector<std::vector<SizeType32>>>>(std::nullopt), //
multimodalPositions.has_value()
? std::make_shared<std::vector<SizeType32>>(std::move(multimodalPositions.value())) //
: std::optional<std::shared_ptr<std::vector<SizeType32>>>(std::nullopt), //
multimodalLengths.has_value()
? std::make_shared<std::vector<SizeType32>>(std::move(multimodalLengths.value())) //
: std::optional<std::shared_ptr<std::vector<SizeType32>>>(std::nullopt), //
multimodalEmbedding, //
mropeRotaryCosSin, //
mropePositionDeltas, //
loraTaskId, //
loraWeights, //
loraConfig, //
lookaheadConfig, //
kvCacheRetentionConfig, //
returnLogProbs, //
returnContextLogits, //
returnGenerationLogits, //
draftTokens.has_value() ? std::make_shared<VecTokens>(std::move(draftTokens.value())) //
: std::make_shared<VecTokens>(), //
draftLogits, //
excludeInputFromOutput, //
logitsPostProcessor, //
applyLogitsPostProcessorBatched, //
encoderInputTokens ? std::make_optional(std::make_shared<VecTokens>(std::move(*encoderInputTokens))) //
: std::optional<std::shared_ptr<VecTokens>>(std::nullopt), //
returnEncoderOutput, //
clientId, //
priority, //
encoderInputFeatures, //
encoderOutputLength, //
crossAttentionMask, //
llmRequestType, //
inputTokenExtraIds //
? std::make_optional(std::make_shared<VecTokenExtraIds>(std::move(*inputTokenExtraIds))) //
: std::optional<std::shared_ptr<VecTokenExtraIds>>(std::nullopt), //
numReturnSequences, //
eagleConfig, //
skipCrossAttnBlocks, //
returnPerfMetrics, //
guidedDecodingParams, //
languageAdapterUid, //
allottedTimeMs, //
contextPhaseParams, //
cacheSaltID, //
arrivalTime //
)
{
}
static std::optional<tb::LlmRequest::LogitsPostProcessor> callbackAdapter(
std::optional<LlmRequest::LogitsPostProcessor> callback);
[[nodiscard]] std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> toTrtLlm() const;
};
} // namespace tensorrt_llm::pybind::batch_manager

View File

@ -1,518 +0,0 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <pybind11/cast.h>
#include <pybind11/chrono.h>
#include <pybind11/functional.h>
#include <pybind11/operators.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <torch/extension.h>
#include <vector>
#include "tensorrt_llm/batch_manager/peftCacheManagerConfig.h"
#include "tensorrt_llm/common/quantization.h"
#include "tensorrt_llm/pybind/batch_manager/algorithms.h"
#include "tensorrt_llm/pybind/batch_manager/bindings.h"
#include "tensorrt_llm/pybind/batch_manager/buffers.h"
#include "tensorrt_llm/pybind/batch_manager/cacheTransceiver.h"
#include "tensorrt_llm/pybind/batch_manager/kvCacheConnector.h"
#include "tensorrt_llm/pybind/batch_manager/kvCacheManager.h"
#include "tensorrt_llm/pybind/batch_manager/kvCacheManagerV2Utils.h"
#include "tensorrt_llm/pybind/batch_manager/llmRequest.h"
#include "tensorrt_llm/pybind/common/tllmExceptions.h"
#include "tensorrt_llm/pybind/executor/bindings.h"
#include "tensorrt_llm/pybind/process_group/bindings.h"
#include "tensorrt_llm/pybind/runtime/bindings.h"
#include "tensorrt_llm/pybind/testing/modelSpecBinding.h"
#include "tensorrt_llm/pybind/thop/bindings.h"
#include "tensorrt_llm/pybind/userbuffers/bindings.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/cudaStream.h"
#include "tensorrt_llm/runtime/gptJsonConfig.h"
#include "tensorrt_llm/runtime/ipcNvlsMemory.h"
#include "tensorrt_llm/runtime/memoryCounters.h"
#include "tensorrt_llm/runtime/samplingConfig.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
namespace py = pybind11;
namespace tb = tensorrt_llm::batch_manager;
namespace tpb = tensorrt_llm::pybind::batch_manager;
namespace tc = tensorrt_llm::common;
namespace tr = tensorrt_llm::runtime;
namespace tle = tensorrt_llm::executor;
using SizeType32 = tr::SizeType32;
using TokenIdType = tr::TokenIdType;
template <typename T>
using OptVec = std::optional<std::vector<T>>;
#if not defined(TRTLLM_PYBIND_MODULE)
#error "TRTLLM_PYBIND_MODULE must be defined"
#endif
namespace
{
tr::SamplingConfig makeSamplingConfig(std::vector<tr::SamplingConfig> const& configs)
{
return tr::SamplingConfig(configs);
}
} // namespace
PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
{
m.doc() = "TensorRT LLM Python bindings for C++ runtime";
m.attr("binding_type") = "pybind";
// Create MpiComm binding first since it's used in the executor bindings
py::classh<tensorrt_llm::mpi::MpiComm>(m, "MpiComm")
.def_static("rank",
[]()
{
auto& session = tensorrt_llm::mpi::MpiComm::session();
return session.tensorrt_llm::mpi::MpiComm::getRank();
})
.def_static("size",
[]()
{
auto& session = tensorrt_llm::mpi::MpiComm::session();
return session.tensorrt_llm::mpi::MpiComm::getSize();
})
.def_static("local_size",
[]()
{
auto& session = tensorrt_llm::mpi::MpiComm::localSession();
return session.tensorrt_llm::mpi::MpiComm::getSize();
})
.def_static("local_init", []() { tensorrt_llm::mpi::MpiComm::localSession(); })
.def_static("set_raw_mpi_session_by_fortran_handle",
[](int64_t fortran_handle) { tensorrt_llm::mpi::MpiComm::setRawSessionByFortran(fortran_handle); })
.def_static("split",
[](size_t color, size_t rank)
{
auto& world = tensorrt_llm::mpi::MpiComm::world();
tensorrt_llm::mpi::MpiComm::setSession(world.split(color, rank));
});
py::classh<tr::CudaStream>(m, "CudaStream")
.def(py::init(
[](py::object py_stream)
{
cudaStream_t stream = reinterpret_cast<cudaStream_t>(py_stream.cast<uintptr_t>());
return tr::CudaStream{stream};
}),
py::arg("stream_ptr"))
.def("get_device", &tr::CudaStream::getDevice);
// Create submodule for executor bindings.
auto mExecutor = m.def_submodule("executor", "Executor bindings");
auto mInternal = m.def_submodule("internal", "Internal submodule of TRTLLM runtime");
auto mInternalProcessGroup = mInternal.def_submodule("process_group", "PyTorch ProcessGroup internal bindings");
auto mInternalRuntime = mInternal.def_submodule("runtime", "Runtime internal bindings");
auto mInternalTesting = mInternal.def_submodule("testing", "Testing internal bindings");
auto mInternalBatchManager = mInternal.def_submodule("batch_manager", "Batch manager internal bindings");
auto mInternalBatchManagerKvCacheV2Utils
= mInternalBatchManager.def_submodule("kv_cache_manager_v2_utils", "KV Cache Manager V2 Utils bindings");
auto mInternalThop = mInternal.def_submodule("thop", "Torch op internal bindings");
auto mExceptions = m.def_submodule("exceptions", "Exceptions internal bindings");
tensorrt_llm::pybind::executor::initBindings(mExecutor);
tensorrt_llm::pybind::runtime::initBindingsEarly(mInternalRuntime);
tensorrt_llm::pybind::common::initExceptionsBindings(mExceptions);
tensorrt_llm::pybind::thop::initBindings(mInternalThop);
auto buildInfo = m.def_submodule("BuildInfo");
buildInfo.attr("ENABLE_MULTI_DEVICE") = py::int_(ENABLE_MULTI_DEVICE);
py::class_<tb::PeftCacheManagerConfig>(m, "PeftCacheManagerConfig")
.def(py::init<SizeType32, SizeType32, SizeType32, SizeType32, SizeType32, SizeType32, SizeType32, SizeType32,
SizeType32, std::optional<float>, std::optional<size_t>, std::optional<std::string>>(),
py::arg("num_host_module_layer") = 0, py::arg("num_device_module_layer") = 0,
py::arg("optimal_adapter_size") = 8, py::arg("max_adapter_size") = 64, py::arg("num_put_workers") = 1,
py::arg("num_ensure_workers") = 1, py::arg("num_copy_streams") = 1,
py::arg("max_pages_per_block_host") = 24, py::arg("max_pages_per_block_device") = 8,
py::arg("device_cache_percent") = std::nullopt, py::arg("host_cache_size") = std::nullopt,
py::arg("lora_prefetch_dir") = std::nullopt)
.def_readwrite("num_host_module_layer", &tb::PeftCacheManagerConfig::numHostModuleLayer)
.def_readwrite("num_device_module_layer", &tb::PeftCacheManagerConfig::numDeviceModuleLayer)
.def_readwrite("optimal_adapter_size", &tb::PeftCacheManagerConfig::optimalAdapterSize)
.def_readwrite("max_adapter_size", &tb::PeftCacheManagerConfig::maxAdapterSize)
.def_readwrite("num_put_workers", &tb::PeftCacheManagerConfig::numPutWorkers)
.def_readwrite("num_ensure_workers", &tb::PeftCacheManagerConfig::numEnsureWorkers)
.def_readwrite("num_copy_streams", &tb::PeftCacheManagerConfig::numCopyStreams)
.def_readwrite("max_pages_per_block_host", &tb::PeftCacheManagerConfig::maxPagesPerBlockHost)
.def_readwrite("max_pages_per_block_device", &tb::PeftCacheManagerConfig::maxPagesPerBlockDevice)
.def_readwrite("device_cache_percent", &tb::PeftCacheManagerConfig::deviceCachePercent)
.def_readwrite("host_cache_size", &tb::PeftCacheManagerConfig::hostCacheSize)
.def_readwrite("lora_prefetch_dir", &tb::PeftCacheManagerConfig::loraPrefetchDir);
py::enum_<nvinfer1::DataType>(m, "DataType")
.value("FLOAT", nvinfer1::DataType::kFLOAT)
.value("HALF", nvinfer1::DataType::kHALF)
.value("INT8", nvinfer1::DataType::kINT8)
.value("INT32", nvinfer1::DataType::kINT32)
.value("BOOL", nvinfer1::DataType::kBOOL)
.value("UINT8", nvinfer1::DataType::kUINT8)
.value("FP8", nvinfer1::DataType::kFP8)
.value("BF16", nvinfer1::DataType::kBF16)
.value("INT64", nvinfer1::DataType::kINT64)
.value("NVFP4", nvinfer1::DataType::kFP4)
.export_values();
py::enum_<tr::ModelConfig::ModelVariant>(m, "GptModelVariant")
.value("GPT", tr::ModelConfig::ModelVariant::kGpt)
.value("GLM", tr::ModelConfig::ModelVariant::kGlm)
.value("CHATGLM", tr::ModelConfig::ModelVariant::kChatGlm)
.value("MAMBA", tr::ModelConfig::ModelVariant::kMamba)
.value("RECURRENTGEMMA", tr::ModelConfig::ModelVariant::kRecurrentGemma);
py::enum_<tr::ModelConfig::KVCacheType>(m, "KVCacheType")
.value("CONTINUOUS", tr::ModelConfig::KVCacheType::kCONTINUOUS)
.value("PAGED", tr::ModelConfig::KVCacheType::kPAGED)
.value("DISABLED", tr::ModelConfig::KVCacheType::kDISABLED)
.def("from_string", &tr::ModelConfig::KVCacheTypeFromString);
py::enum_<tr::ModelConfig::LayerType>(m, "LayerType")
.value("ATTENTION", tr::ModelConfig::LayerType::kATTENTION)
.value("RECURRENT", tr::ModelConfig::LayerType::kRECURRENT);
py::enum_<tr::LoraModule::ModuleType>(m, "LoraModuleType")
.value("INVALID", tr::LoraModule::ModuleType::kINVALID)
.value("ATTN_QKV", tr::LoraModule::ModuleType::kATTN_QKV)
.value("ATTN_Q", tr::LoraModule::ModuleType::kATTN_Q)
.value("ATTN_K", tr::LoraModule::ModuleType::kATTN_K)
.value("ATTN_V", tr::LoraModule::ModuleType::kATTN_V)
.value("ATTN_DENSE", tr::LoraModule::ModuleType::kATTN_DENSE)
.value("MLP_H_TO_4H", tr::LoraModule::ModuleType::kMLP_H_TO_4H)
.value("MLP_4H_TO_H", tr::LoraModule::ModuleType::kMLP_4H_TO_H)
.value("MLP_GATE", tr::LoraModule::ModuleType::kMLP_GATE)
.value("CROSS_ATTN_QKV", tr::LoraModule::ModuleType::kCROSS_ATTN_QKV)
.value("CROSS_ATTN_Q", tr::LoraModule::ModuleType::kCROSS_ATTN_Q)
.value("CROSS_ATTN_K", tr::LoraModule::ModuleType::kCROSS_ATTN_K)
.value("CROSS_ATTN_V", tr::LoraModule::ModuleType::kCROSS_ATTN_V)
.value("CROSS_ATTN_DENSE", tr::LoraModule::ModuleType::kCROSS_ATTN_DENSE)
.value("MOE_H_TO_4H", tr::LoraModule::ModuleType::kMOE_H_TO_4H)
.value("MOE_4H_TO_H", tr::LoraModule::ModuleType::kMOE_4H_TO_H)
.value("MOE_GATE", tr::LoraModule::ModuleType::kMOE_GATE)
.value("MOE_ROUTER", tr::LoraModule::ModuleType::kMOE_ROUTER)
.value("MLP_ROUTER", tr::LoraModule::ModuleType::kMLP_ROUTER)
.value("MLP_GATE_UP", tr::LoraModule::ModuleType::kMLP_GATE_UP);
py::class_<tr::LoraModule>(m, "LoraModule")
.def(py::init<tr::LoraModule::ModuleType, SizeType32, SizeType32, bool, bool, SizeType32, SizeType32>(),
py::arg("module_type"), py::arg("in_dim"), py::arg("out_dim"), py::arg("in_dim_first"),
py::arg("out_dim_first"), py::arg("in_tp_split_dim"), py::arg("out_tp_split_dim"))
.def_property_readonly("module_type", &tr::LoraModule::name)
.def_property_readonly("in_dim", &tr::LoraModule::inDim)
.def_property_readonly("out_dim", &tr::LoraModule::outDim)
.def_property_readonly("in_dim_first", &tr::LoraModule::inDimFirst)
.def_property_readonly("out_dim_first", &tr::LoraModule::outDimFirst)
.def_property_readonly("in_tp_split_dim", &tr::LoraModule::inTpSplitDim)
.def_property_readonly("out_tp_split_dim", &tr::LoraModule::outTpSplitDim)
.def_static("create_lora_modules", &tr::LoraModule::createLoraModules, py::arg("lora_module_names"),
py::arg("hidden_size"), py::arg("mlp_hidden_size"), py::arg("num_attention_heads"),
py::arg("num_kv_attention_heads"), py::arg("attention_head_size"), py::arg("tp_size") = 1,
py::arg("num_experts") = 0);
py::class_<tc::QuantMode>(m, "QuantMode")
.def_static("none", &tc::QuantMode::none)
.def_static("int4_weights", &tc::QuantMode::int4Weights)
.def_static("int8_weights", &tc::QuantMode::int8Weights)
.def_static("activations", &tc::QuantMode::activations)
.def_static("per_channel_scaling", &tc::QuantMode::perChannelScaling)
.def_static("per_token_scaling", &tc::QuantMode::perTokenScaling)
.def_static("per_group_scaling", &tc::QuantMode::perGroupScaling)
.def_static("int8_kv_cache", &tc::QuantMode::int8KvCache)
.def_static("fp8_kv_cache", &tc::QuantMode::fp8KvCache)
.def_static("fp8_qdq", &tc::QuantMode::fp8Qdq)
.def_property_readonly("value", &tc::QuantMode::value)
.def("is_set", &tc::QuantMode::isSet, py::arg("mode"))
.def_property_readonly("has_int4_weights", &tc::QuantMode::hasInt4Weights)
.def_property_readonly("has_int8_weights", &tc::QuantMode::hasInt8Weights)
.def_property_readonly("has_activations", &tc::QuantMode::hasActivations)
.def_property_readonly("has_per_channel_scaling", &tc::QuantMode::hasPerChannelScaling)
.def_property_readonly("has_per_token_scaling", &tc::QuantMode::hasPerTokenScaling)
.def_property_readonly("has_per_group_scaling", &tc::QuantMode::hasPerGroupScaling)
.def_property_readonly("has_static_activation_scaling", &tc::QuantMode::hasStaticActivationScaling)
.def_property_readonly("has_int8_kv_cache", &tc::QuantMode::hasInt8KvCache)
.def_property_readonly("has_fp4_kv_cache", &tc::QuantMode::hasFp4KvCache)
.def_property_readonly("has_fp8_kv_cache", &tc::QuantMode::hasFp8KvCache)
.def_property_readonly("has_fp8_qdq", &tc::QuantMode::hasFp8Qdq)
.def_property_readonly("has_nvfp4", &tc::QuantMode::hasNvfp4)
.def_property_readonly("has_w4a8_mxfp4_fp8", &tc::QuantMode::hasW4a8Mxfp4Fp8)
.def_property_readonly("has_w4a8_mxfp4_mxfp8", &tc::QuantMode::hasW4a8Mxfp4Mxfp8)
.def_property_readonly("has_w4a16_mxfp4", &tc::QuantMode::hasW4a16Mxfp4)
.def_property_readonly("has_kv_cache_quant", &tc::QuantMode::hasKvCacheQuant)
.def_static("from_description", &tc::QuantMode::fromDescription, py::arg("quantize_weights"),
py::arg("quantize_activations"), py::arg("per_token"), py::arg("per_channel"), py::arg("per_group"),
py::arg("use_int4_weights"), py::arg("use_int8_kv_cache"), py::arg("use_fp8_kv_kache"),
py::arg("use_fp8_qdq"), py::arg("use_fp8_rowwise"), py::arg("use_w4a8_qserve"), py::arg("use_nvfp4"),
py::arg("use_fp8_block_scales"), py::arg("use_w4a8_mxfp4_fp8"), py::arg("use_w4a8_mxfp4_mxfp8"),
py::arg("use_w4a16_mxfp4"))
.def_static("use_smooth_quant", &tc::QuantMode::useSmoothQuant, py::arg("per_token") = false,
py::arg("per_channel") = false)
.def_static("use_weight_only", &tc::QuantMode::useWeightOnly, py::arg("use_int4_weights") = false,
py::arg("per_group") = false)
.def_static("from_quant_algo", &tc::QuantMode::fromQuantAlgo, py::arg("quant_algo") = py::none(),
py::arg("kv_cache_quant_algo") = py::none())
.def(py::self + py::self)
.def(py::self += py::self)
.def(py::self - py::self)
.def(py::self -= py::self)
.def(py::self == py::self)
.def(py::self != py::self);
py::class_<tr::ModelConfig>(m, "ModelConfig")
.def(py::init<SizeType32, SizeType32, SizeType32, SizeType32, SizeType32, SizeType32, nvinfer1::DataType>(),
py::arg("vocab_size"), py::arg("num_layers"), py::arg("num_attention_layers"), py::arg("num_rnn_layers"),
py::arg("num_heads"), py::arg("hidden_size"), py::arg("data_type"))
.def_property_readonly("vocab_size", &tr::ModelConfig::getVocabSize)
.def("vocab_size_padded", &tr::ModelConfig::getVocabSizePadded, py::arg("world_size"))
.def("num_layers", &tr::ModelConfig::getNbLayers, py::arg("pipeline_parallelism") = 1,
py::arg("pipeline_parallelism_rank") = 0)
.def("num_attention_layers", &tr::ModelConfig::getNbAttentionLayers, py::arg("pipeline_parallelism") = 1,
py::arg("pipeline_parallelism_rank") = 0)
.def("num_rnn_layers", &tr::ModelConfig::getNbRnnLayers, py::arg("pipeline_parallelism") = 1,
py::arg("pipeline_parallelism_rank") = 0)
.def("num_kv_heads", &tr::ModelConfig::getNbKvHeads, py::arg("layer_idx"))
.def("set_num_kv_heads", &tr::ModelConfig::setNbKvHeads, py::arg("num_kv_heads"))
.def_property_readonly("num_heads", &tr::ModelConfig::getNbHeads)
.def_property_readonly("hidden_size", &tr::ModelConfig::getHiddenSize)
.def_property_readonly("size_per_head", &tr::ModelConfig::getSizePerHead)
.def_property_readonly("data_type", &tr::ModelConfig::getDataType)
.def_property_readonly("speculative_decoding_mode", &tr::ModelConfig::getSpeculativeDecodingMode)
.def_property("head_size", &tr::ModelConfig::getSizePerHead, &tr::ModelConfig::setSizePerHead)
.def_property(
"num_kv_heads_per_layer", &tr::ModelConfig::getNumKvHeadsPerLayer, &tr::ModelConfig::setNumKvHeadsPerLayer)
.def_property("use_gpt_attention_plugin",
py::overload_cast<>(&tr::ModelConfig::useGptAttentionPlugin, py::const_),
py::overload_cast<bool>(&tr::ModelConfig::useGptAttentionPlugin))
.def_property("use_packed_input", py::overload_cast<>(&tr::ModelConfig::usePackedInput, py::const_),
py::overload_cast<bool>(&tr::ModelConfig::usePackedInput))
.def_property("kv_cache_type", py::overload_cast<>(&tr::ModelConfig::getKVCacheType, py::const_),
py::overload_cast<tr::ModelConfig::KVCacheType>(&tr::ModelConfig::setKVCacheType))
.def_property("tokens_per_block", &tr::ModelConfig::getTokensPerBlock, &tr::ModelConfig::setTokensPerBlock)
.def_property("quant_mode", &tr::ModelConfig::getQuantMode, &tr::ModelConfig::setQuantMode)
.def_property_readonly("supports_inflight_batching", &tr::ModelConfig::supportsInflightBatching)
.def_property("max_batch_size", &tr::ModelConfig::getMaxBatchSize, &tr::ModelConfig::setMaxBatchSize)
.def_property("max_beam_width", &tr::ModelConfig::getMaxBeamWidth, &tr::ModelConfig::setMaxBeamWidth)
.def_property("max_input_len", &tr::ModelConfig::getMaxInputLen, &tr::ModelConfig::setMaxInputLen)
.def_property("max_seq_len", &tr::ModelConfig::getMaxSequenceLen, &tr::ModelConfig::setMaxSequenceLen)
.def_property("max_num_tokens", &tr::ModelConfig::getMaxNumTokens, &tr::ModelConfig::setMaxNumTokens)
.def_property("max_prompt_embedding_table_size", &tr::ModelConfig::getMaxPromptEmbeddingTableSize,
&tr::ModelConfig::setMaxPromptEmbeddingTableSize)
.def_property_readonly("use_prompt_tuning", &tr::ModelConfig::usePromptTuning)
.def_property_readonly("use_mrope", &tr::ModelConfig::useMrope)
.def_property("use_lora_plugin", py::overload_cast<>(&tr::ModelConfig::useLoraPlugin, py::const_),
py::overload_cast<bool>(&tr::ModelConfig::useLoraPlugin))
.def_property("layer_types", &tr::ModelConfig::getLayerTypes, &tr::ModelConfig::setLayerTypes)
.def_property("compute_context_logits", py::overload_cast<>(&tr::ModelConfig::computeContextLogits, py::const_),
py::overload_cast<bool>(&tr::ModelConfig::computeContextLogits))
.def_property("compute_generation_logits",
py::overload_cast<>(&tr::ModelConfig::computeGenerationLogits, py::const_),
py::overload_cast<bool>(&tr::ModelConfig::computeGenerationLogits))
.def_property("model_variant", &tr::ModelConfig::getModelVariant, &tr::ModelConfig::setModelVariant)
.def_property(
"use_cross_attention", &tr::ModelConfig::useCrossAttention, &tr::ModelConfig::setUseCrossAttention)
.def_property("lora_modules", &tr::ModelConfig::getLoraModules, &tr::ModelConfig::setLoraModules)
.def_property("max_lora_rank", &tr::ModelConfig::getMaxLoraRank, &tr::ModelConfig::setMaxLoraRank)
.def_property("mlp_hidden_size", &tr::ModelConfig::getMlpHiddenSize, &tr::ModelConfig::setMlpHiddenSize)
.def_property("size_per_head", &tr::ModelConfig::getSizePerHead, &tr::ModelConfig::setSizePerHead);
py::class_<tr::WorldConfig>(m, "WorldConfig")
.def(py::init<SizeType32, SizeType32, SizeType32, SizeType32, SizeType32,
std::optional<std::vector<SizeType32>> const&, bool>(),
py::arg("tensor_parallelism") = 1, py::arg("pipeline_parallelism") = 1, py::arg("context_parallelism") = 1,
py::arg("rank") = 0, py::arg("gpus_per_node") = tr::WorldConfig::kDefaultGpusPerNode,
py::arg("device_ids") = py::none(), py::arg("enable_attention_dp") = false)
.def_property_readonly("size", &tr::WorldConfig::getSize)
.def_property_readonly("tensor_parallelism", &tr::WorldConfig::getTensorParallelism)
.def_property_readonly("pipeline_parallelism", &tr::WorldConfig::getPipelineParallelism)
.def_property_readonly("context_parallelism", &tr::WorldConfig::getContextParallelism)
.def_property_readonly("is_tensor_parallel", &tr::WorldConfig::isTensorParallel)
.def_property_readonly("is_pipeline_parallel", &tr::WorldConfig::isPipelineParallel)
.def_property_readonly("is_context_parallel", &tr::WorldConfig::isContextParallel)
.def_property_readonly("rank", &tr::WorldConfig::getRank)
.def_property_readonly("local_rank", &tr::WorldConfig::getLocalRank)
.def_property_readonly("node_rank", &tr::WorldConfig::getNodeRank)
.def_property_readonly("gpus_per_node", &tr::WorldConfig::getGpusPerNode)
.def_property_readonly("gpus_per_group", &tr::WorldConfig::getGpusPerGroup)
.def_property_readonly("device", &tr::WorldConfig::getDevice)
.def_property_readonly("pipeline_parallel_rank", &tr::WorldConfig::getPipelineParallelRank)
.def_property_readonly("tensor_parallel_rank", &tr::WorldConfig::getTensorParallelRank)
.def_property_readonly("context_parallel_rank", &tr::WorldConfig::getContextParallelRank)
.def_property_readonly("enable_attention_dp", &tr::WorldConfig::enableAttentionDP)
.def_static("mpi",
py::overload_cast<SizeType32, std::optional<SizeType32>, std::optional<SizeType32>,
std::optional<SizeType32>, std::optional<std::vector<SizeType32>> const&, bool>(&tr::WorldConfig::mpi),
py::arg("gpus_per_node") = tr::WorldConfig::kDefaultGpusPerNode, py::arg("tensor_parallelism") = py::none(),
py::arg("pipeline_parallelism") = py::none(), py::arg("context_parallelism") = py::none(),
py::arg("device_ids") = py::none(), py::arg("enable_attention_dp") = false);
auto SamplingConfigGetState = [](tr::SamplingConfig const& config) -> py::tuple
{
return py::make_tuple(config.beamWidth, config.temperature, config.minLength, config.repetitionPenalty,
config.presencePenalty, config.frequencyPenalty, config.promptIgnoreLength, config.topK, config.topP,
config.randomSeed, config.topPDecay, config.topPMin, config.topPResetIds, config.beamSearchDiversityRate,
config.lengthPenalty, config.earlyStopping, config.noRepeatNgramSize, config.numReturnSequences,
config.minP, config.beamWidthArray);
};
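    // The tuple consumed by SamplingConfigSetState below must mirror the layout produced by SamplingConfigGetState
    // above (20 entries, same order). Note that normalizeLogProbs is not part of the pickled state and keeps its
    // default value after unpickling.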
auto SamplingConfigSetState = [](py::tuple t) -> tr::SamplingConfig
{
if (t.size() != 20)
{
throw std::runtime_error("Invalid SamplingConfig state!");
}
tr::SamplingConfig config;
config.beamWidth = t[0].cast<SizeType32>();
config.temperature = t[1].cast<OptVec<float>>();
config.minLength = t[2].cast<OptVec<SizeType32>>();
config.repetitionPenalty = t[3].cast<OptVec<float>>();
config.presencePenalty = t[4].cast<OptVec<float>>();
config.frequencyPenalty = t[5].cast<OptVec<float>>();
config.promptIgnoreLength = t[6].cast<OptVec<SizeType32>>();
config.topK = t[7].cast<OptVec<SizeType32>>();
config.topP = t[8].cast<OptVec<float>>();
config.randomSeed = t[9].cast<OptVec<uint64_t>>();
config.topPDecay = t[10].cast<OptVec<float>>();
config.topPMin = t[11].cast<OptVec<float>>();
config.topPResetIds = t[12].cast<OptVec<TokenIdType>>();
config.beamSearchDiversityRate = t[13].cast<OptVec<float>>();
config.lengthPenalty = t[14].cast<OptVec<float>>();
config.earlyStopping = t[15].cast<OptVec<SizeType32>>();
config.noRepeatNgramSize = t[16].cast<OptVec<SizeType32>>();
config.numReturnSequences = t[17].cast<SizeType32>();
config.minP = t[18].cast<OptVec<float>>();
config.beamWidthArray = t[19].cast<OptVec<std::vector<SizeType32>>>();
return config;
};
py::classh<tr::SamplingConfig>(m, "SamplingConfig")
.def(py::init<SizeType32>(), py::arg("beam_width") = 1)
.def(py::init<tle::SamplingConfig, std::optional<tle::ExternalDraftTokensConfig>>(),
py::arg("executor_sample_config"), py::arg("external_draft_tokens_config") = std::nullopt)
.def_readwrite("beam_width", &tr::SamplingConfig::beamWidth)
.def_readwrite("temperature", &tr::SamplingConfig::temperature)
.def_readwrite("min_length", &tr::SamplingConfig::minLength)
.def_readwrite("repetition_penalty", &tr::SamplingConfig::repetitionPenalty)
.def_readwrite("presence_penalty", &tr::SamplingConfig::presencePenalty)
.def_readwrite("frequency_penalty", &tr::SamplingConfig::frequencyPenalty)
.def_readwrite("prompt_ignore_length", &tr::SamplingConfig::promptIgnoreLength)
.def_readwrite("top_k", &tr::SamplingConfig::topK)
.def_readwrite("top_p", &tr::SamplingConfig::topP)
.def_readwrite("random_seed", &tr::SamplingConfig::randomSeed)
.def_readwrite("top_p_decay", &tr::SamplingConfig::topPDecay)
.def_readwrite("top_p_min", &tr::SamplingConfig::topPMin)
.def_readwrite("top_p_reset_ids", &tr::SamplingConfig::topPResetIds)
.def_readwrite("beam_search_diversity_rate", &tr::SamplingConfig::beamSearchDiversityRate)
.def_readwrite("length_penalty", &tr::SamplingConfig::lengthPenalty)
.def_readwrite("early_stopping", &tr::SamplingConfig::earlyStopping)
.def_readwrite("no_repeat_ngram_size", &tr::SamplingConfig::noRepeatNgramSize)
.def_readwrite("num_return_sequences", &tr::SamplingConfig::numReturnSequences)
.def_readwrite("min_p", &tr::SamplingConfig::minP)
.def_readwrite("beam_width_array", &tr::SamplingConfig::beamWidthArray)
.def_readwrite("normalize_log_probs", &tr::SamplingConfig::normalizeLogProbs)
.def(py::pickle(SamplingConfigGetState, SamplingConfigSetState))
.def("__eq__", &tr::SamplingConfig::operator==);
m.def("make_sampling_config", &makeSamplingConfig, py::arg("configs"));
py::class_<tr::GptJsonConfig>(m, "GptJsonConfig")
.def(py::init<std::string, std::string, std::string, SizeType32, SizeType32, SizeType32, SizeType32,
tr::ModelConfig, std::optional<tr::RuntimeDefaults>>(),
py::arg("name"), py::arg("version"), py::arg("precision"), py::arg("tensor_parallelism"),
py::arg("pipeline_parallelism"), py::arg("context_parallelism"), py::arg("gpus_per_node"),
py::arg("model_config"), py::arg("runtime_defaults") = py::none())
.def_static("parse", py::overload_cast<std::string const&>(&tr::GptJsonConfig::parse), py::arg("json"))
.def_static(
"parse_file", py::overload_cast<std::filesystem::path const&>(&tr::GptJsonConfig::parse), py::arg("path"))
.def_property_readonly("model_config", &tr::GptJsonConfig::getModelConfig)
.def_property_readonly("name", &tr::GptJsonConfig::getName)
.def_property_readonly("version", &tr::GptJsonConfig::getVersion)
.def_property_readonly("precision", &tr::GptJsonConfig::getPrecision)
.def_property_readonly("tensor_parallelism", &tr::GptJsonConfig::getTensorParallelism)
.def_property_readonly("pipeline_parallelism", &tr::GptJsonConfig::getPipelineParallelism)
.def_property_readonly("context_parallelism", &tr::GptJsonConfig::getContextParallelism)
.def_property_readonly("gpus_per_node", &tr::GptJsonConfig::getGpusPerNode)
.def_property_readonly("world_size", &tr::GptJsonConfig::getWorldSize)
.def_property_readonly("runtime_defaults", &tr::GptJsonConfig::getRuntimeDefaults)
.def("engine_filename",
py::overload_cast<tr::WorldConfig const&, std::string const&>(
&tr::GptJsonConfig::engineFilename, py::const_),
py::arg("world_config"), py::arg("model"))
.def("engine_filename",
py::overload_cast<tr::WorldConfig const&>(&tr::GptJsonConfig::engineFilename, py::const_),
py::arg("world_config"));
py::enum_<tb::LlmRequestState>(m, "LlmRequestState")
.value("UNKNOWN", tb::LlmRequestState::kUNKNOWN)
.value("ENCODER_INIT", tb::LlmRequestState::kENCODER_INIT)
.value("CONTEXT_INIT", tb::LlmRequestState::kCONTEXT_INIT)
.value("GENERATION_IN_PROGRESS", tb::LlmRequestState::kGENERATION_IN_PROGRESS)
.value("GENERATION_TO_COMPLETE", tb::LlmRequestState::kGENERATION_TO_COMPLETE)
.value("GENERATION_COMPLETE", tb::LlmRequestState::kGENERATION_COMPLETE)
.value("DISAGG_GENERATION_INIT", tb::LlmRequestState::kDISAGG_GENERATION_INIT)
.value("DISAGG_CONTEXT_TRANS_IN_PROGRESS", tb::LlmRequestState::kDISAGG_CONTEXT_TRANS_IN_PROGRESS)
.value("DISAGG_CONTEXT_COMPLETE", tb::LlmRequestState::kDISAGG_CONTEXT_COMPLETE)
.value("DISAGG_GENERATION_TRANS_IN_PROGRESS", tb::LlmRequestState::kDISAGG_GENERATION_TRANS_IN_PROGRESS)
.value("DISAGG_TRANS_ERROR", tb::LlmRequestState::kDISAGG_TRANS_ERROR)
.value("DISAGG_GENERATION_TRANS_COMPLETE", tb::LlmRequestState::kDISAGG_GENERATION_TRANS_COMPLETE)
.value("DISAGG_CONTEXT_INIT_AND_TRANS", tb::LlmRequestState::kDISAGG_CONTEXT_INIT_AND_TRANS)
.value("DISAGG_CONTEXT_WAIT_SCHEDULER", tb::LlmRequestState::kDISAGG_CONTEXT_WAIT_SCHEDULER)
.value("DISAGG_GENERATION_WAIT_TOKENS", tb::LlmRequestState::kDISAGG_GENERATION_WAIT_TOKENS);
py::class_<tr::MemoryCounters>(m, "MemoryCounters")
.def_static("instance", &tr::MemoryCounters::getInstance, py::return_value_policy::reference)
.def_property_readonly("gpu", &tr::MemoryCounters::getGpu)
.def_property_readonly("cpu", &tr::MemoryCounters::getCpu)
.def_property_readonly("pinned", &tr::MemoryCounters::getPinned)
.def_property_readonly("uvm", &tr::MemoryCounters::getUVM);
tensorrt_llm::pybind::process_group::initBindings(mInternalProcessGroup);
tpb::Buffers::initBindings(mInternalBatchManager);
tensorrt_llm::pybind::runtime::initBindings(mInternalRuntime);
tensorrt_llm::pybind::testing::initBindings(mInternalTesting);
tpb::initBindings(mInternalBatchManager);
tb::kv_cache_manager::KVCacheManagerConnectorBindings::initBindings(mInternalBatchManager);
tb::kv_cache_manager::KVCacheManagerBindings::initBindings(mInternalBatchManager);
tb::BasePeftCacheManagerBindings::initBindings(mInternalBatchManager);
tb::CacheTransceiverBindings::initBindings(mInternalBatchManager);
tb::kv_cache_manager_v2::KVCacheManagerV2UtilsBindings::initBindings(mInternalBatchManagerKvCacheV2Utils);
auto mInternalAlgorithms = mInternal.def_submodule("algorithms", "Algorithms internal bindings");
tpb::algorithms::initBindings(mInternalAlgorithms);
auto mUserbuffers = mInternal.def_submodule("userbuffers", "User buffers internal bindings");
tensorrt_llm::kernels::userbuffers::UserBufferBindings::initBindings(mUserbuffers);
// NVLS allocators
py::class_<tr::IpcNvlsHandle>(m, "IpcNvlsHandle")
.def(py::init<>())
.def_readwrite("uc_ptr", &tr::IpcNvlsHandle::uc_ptr)
.def_readwrite("mc_ptr", &tr::IpcNvlsHandle::mc_ptr)
.def_readwrite("size", &tr::IpcNvlsHandle::size)
.def("get_ipc_ptrs",
[](tr::IpcNvlsHandle& self) { return reinterpret_cast<uintptr_t>(self.ipc_uc_ptrs.data()); });
m.def("ipc_nvls_allocate", &tr::ipcNvlsAllocate, py::return_value_policy::reference);
m.def("ipc_nvls_free", &tr::ipcNvlsFree);
m.def("ipc_nvls_supported", &tr::ipcNvlsSupported);
m.def("steady_clock_now", []() { return std::chrono::steady_clock::now(); });
}

View File

@ -1,94 +0,0 @@
/*
* Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/pybind/common/customCasters.h"
#include <pybind11/pybind11.h>
namespace PybindUtils
{
namespace py = pybind11;
template <typename T>
void bindList(py::module& m, std::string const& name)
{
py::class_<T>(m, name.c_str())
.def(py::init())
.def("push_back", [](T& lst, const typename T::value_type& value) { lst.push_back(value); })
.def("pop_back", [](T& lst) { lst.pop_back(); })
.def("push_front", [](T& lst, const typename T::value_type& value) { lst.push_front(value); })
.def("pop_front", [](T& lst) { lst.pop_front(); })
.def("__len__", [](T const& lst) { return lst.size(); })
.def(
"__iter__", [](T& lst) { return py::make_iterator(lst.begin(), lst.end()); }, py::keep_alive<0, 1>())
.def("__getitem__",
[](T const& lst, size_t index)
{
if (index >= lst.size())
throw py::index_error();
auto it = lst.begin();
std::advance(it, index);
return *it;
})
.def("__setitem__",
[](T& lst, size_t index, const typename T::value_type& value)
{
if (index >= lst.size())
throw py::index_error();
auto it = lst.begin();
std::advance(it, index);
*it = value;
});
}
template <typename T>
void bindSet(py::module& m, std::string const& name)
{
py::class_<T>(m, name.c_str())
.def(py::init())
.def("clear", &T::clear)
.def("size", &T::size)
.def("insert", [](T& s, typename T::value_type const& value) { s.insert(value); })
.def("erase", py::overload_cast<typename T::value_type const&>(&T::erase))
.def("__len__", [](T const& lst) { return lst.size(); })
.def("__contains__", [](T const& s, typename T::value_type x) { return s.find(x) != s.end(); })
.def(
"__iter__", [](T& s) { return py::make_iterator(s.begin(), s.end()); }, py::keep_alive<0, 1>())
.def("__eq__", [](T const& s, T const& other) { return s == other; })
.def(py::pickle(
[](T const& s) { // __getstate__
/* Return a tuple that fully encodes the state of the object */
return py::make_tuple(std::vector<typename T::value_type>(s.begin(), s.end()));
},
[](py::tuple t) { // __setstate__
if (t.size() != 1)
throw std::runtime_error("Invalid state!");
/* Create a new C++ instance */
T s;
/* Assign any additional state */
auto state_list = t[0].cast<std::vector<typename T::value_type>>();
for (auto& item : state_list)
{
s.insert(item);
}
return s;
}));
}
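// Illustrative usage sketch (an editor's addition, with a hypothetical module handle `m`): these helpers are
// instantiated once per opaque container type from a binding translation unit, in the spirit of pybind11's
// bind_vector/bind_map, e.g.
//
//   PybindUtils::bindSet<tensorrt_llm::batch_manager::ReqIdsSet>(m, "ReqIdsSet");
//   PybindUtils::bindList<std::list<int>>(m, "IntList");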
} // namespace PybindUtils

View File

@ -1,265 +0,0 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/batch_manager/common.h"
#include "tensorrt_llm/batch_manager/decoderBuffers.h"
#include "tensorrt_llm/common/optionalRef.h"
#include "tensorrt_llm/runtime/cudaStream.h"
#include "tensorrt_llm/runtime/torch.h"
#include "tensorrt_llm/runtime/torchView.h"
#include <pybind11/cast.h>
#include <pybind11/detail/common.h>
#include <pybind11/detail/descr.h>
#include <pybind11/pybind11.h>
#include <pybind11/pytypes.h>
#include <pybind11/stl_bind.h>
#include <torch/extension.h>
// Pybind11 requires custom type casters to come from a single central include in order to work correctly.
// Opaque bindings register a type caster, so they share the same requirement.
// See the warning in https://pybind11.readthedocs.io/en/stable/advanced/cast/custom.html
// Opaque bindings
PYBIND11_MAKE_OPAQUE(tensorrt_llm::batch_manager::ReqIdsSet)
PYBIND11_MAKE_OPAQUE(std::vector<tensorrt_llm::batch_manager::SlotDecoderBuffers>)
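// Note (an editor's addition): binding translation units are expected to include this header, as the other headers
// in this change do (e.g. pybindUtils.h and executor/bindings.h), so that every translation unit sees the same
// casters and opaque types:
//
//   #include "tensorrt_llm/pybind/common/customCasters.h"
//   #include <pybind11/pybind11.h>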
// Custom casters
namespace PYBIND11_NAMESPACE
{
namespace detail
{
template <typename T>
struct type_caster<tensorrt_llm::common::OptionalRef<T>>
{
using value_conv = make_caster<T>;
PYBIND11_TYPE_CASTER(tensorrt_llm::common::OptionalRef<T>, value_conv::name);
bool load(handle src, bool convert)
{
if (src.is_none())
{
// If the Python object is None, create an empty OptionalRef
value = tensorrt_llm::common::OptionalRef<T>();
return true;
}
value_conv conv;
if (!conv.load(src, convert))
return false;
// Create an OptionalRef with a reference to the converted value
value = tensorrt_llm::common::OptionalRef<T>(conv);
return true;
}
static handle cast(tensorrt_llm::common::OptionalRef<T> const& src, return_value_policy policy, handle parent)
{
if (!src.has_value())
return none().release();
return value_conv::cast(*src, policy, parent);
}
};
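// Illustrative sketch (an editor's addition; the binding name is hypothetical): with this caster in scope, a bound
// callable may take OptionalRef<T> and Python callers can pass either a T or None, e.g.
//
//   m.def("has_decoder_buffers",
//       [](tensorrt_llm::common::OptionalRef<tensorrt_llm::batch_manager::DecoderBuffers> buffers)
//       { return buffers.has_value(); });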
template <typename T>
struct PathCaster
{
private:
static PyObject* unicode_from_fs_native(std::string const& w)
{
return PyUnicode_DecodeFSDefaultAndSize(w.c_str(), ssize_t(w.size()));
}
static PyObject* unicode_from_fs_native(std::wstring const& w)
{
return PyUnicode_FromWideChar(w.c_str(), ssize_t(w.size()));
}
public:
static handle cast(T const& path, return_value_policy, handle)
{
if (auto py_str = unicode_from_fs_native(path.native()))
{
return module_::import("pathlib").attr("Path")(reinterpret_steal<object>(py_str)).release();
}
return nullptr;
}
bool load(handle handle, bool)
{
PyObject* native = nullptr;
if constexpr (std::is_same_v<typename T::value_type, char>)
{
if (PyUnicode_FSConverter(handle.ptr(), &native) != 0)
{
if (auto* c_str = PyBytes_AsString(native))
{
// AsString returns a pointer to the internal buffer, which
// must not be free'd.
value = c_str;
}
}
}
else if constexpr (std::is_same_v<typename T::value_type, wchar_t>)
{
if (PyUnicode_FSDecoder(handle.ptr(), &native) != 0)
{
if (auto* c_str = PyUnicode_AsWideCharString(native, nullptr))
{
// AsWideCharString returns a new string that must be free'd.
value = c_str; // Copies the string.
PyMem_Free(c_str);
}
}
}
Py_XDECREF(native);
if (PyErr_Occurred())
{
PyErr_Clear();
return false;
}
return true;
}
PYBIND11_TYPE_CASTER(T, const_name("os.PathLike"));
};
template <>
struct type_caster<std::filesystem::path> : public PathCaster<std::filesystem::path>
{
};
template <>
class type_caster<tensorrt_llm::executor::StreamPtr>
{
public:
PYBIND11_TYPE_CASTER(tensorrt_llm::executor::StreamPtr, _("int"));
bool load([[maybe_unused]] handle src, bool)
{
auto stream_ptr = src.cast<uintptr_t>();
value = std::make_shared<tensorrt_llm::runtime::CudaStream>(reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
static handle cast(
tensorrt_llm::executor::StreamPtr const& src, return_value_policy /* policy */, handle /* parent */)
{
// Return cudaStream_t as integer.
return PyLong_FromVoidPtr(src->get());
}
};
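// Illustrative sketch (an editor's addition; the binding name is hypothetical): this caster lets CUDA streams cross
// the Python/C++ boundary as plain integer addresses, e.g.
//
//   m.def("stream_address", [](tensorrt_llm::executor::StreamPtr const& stream)
//       { return reinterpret_cast<uintptr_t>(stream->get()); });
//
// where a Python caller would pass something like torch.cuda.current_stream().cuda_stream.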
template <>
struct type_caster<tensorrt_llm::executor::Tensor>
{
public:
PYBIND11_TYPE_CASTER(tensorrt_llm::executor::Tensor, _("torch.Tensor"));
// Convert PyObject(torch.Tensor) -> tensorrt_llm::executor::Tensor
bool load(handle src, bool)
{
PyObject* obj = src.ptr();
if (THPVariable_Check(obj))
{
at::Tensor const& t = THPVariable_Unpack(obj);
value = tensorrt_llm::executor::detail::ofITensor(tensorrt_llm::runtime::TorchView::of(t));
return true;
}
return false;
}
// Convert tensorrt_llm::executor::Tensor -> PyObject(torch.Tensor)
static handle cast(tensorrt_llm::executor::Tensor const& src, return_value_policy /* policy */, handle /* parent */)
{
return THPVariable_Wrap(tensorrt_llm::runtime::Torch::tensor(tensorrt_llm::executor::detail::toITensor(src)));
}
};
template <>
struct type_caster<tensorrt_llm::runtime::ITensor::SharedPtr>
{
public:
PYBIND11_TYPE_CASTER(tensorrt_llm::runtime::ITensor::SharedPtr, _("torch.Tensor"));
// Convert PyObject(torch.Tensor) -> tensorrt_llm::runtime::ITensor::SharedPtr
bool load(handle src, bool)
{
PyObject* obj = src.ptr();
if (THPVariable_Check(obj))
{
at::Tensor const& t = THPVariable_Unpack(obj);
value = std::move(tensorrt_llm::runtime::TorchView::of(t));
return true;
}
return false;
}
// Convert tensorrt_llm::runtime::ITensor::SharedPtr -> PyObject(torch.Tensor)
static handle cast(
tensorrt_llm::runtime::ITensor::SharedPtr const& src, return_value_policy /* policy */, handle /* parent */)
{
if (src == nullptr)
{
return none().release();
}
return THPVariable_Wrap(tensorrt_llm::runtime::Torch::tensor(src));
}
};
template <>
struct type_caster<tensorrt_llm::runtime::ITensor::SharedConstPtr>
{
public:
PYBIND11_TYPE_CASTER(tensorrt_llm::runtime::ITensor::SharedConstPtr, _("torch.Tensor"));
// Convert PyObject(torch.Tensor) -> tensorrt_llm::runtime::ITensor::SharedConstPtr
bool load(handle src, bool)
{
PyObject* obj = src.ptr();
if (THPVariable_Check(obj))
{
at::Tensor const& t = THPVariable_Unpack(obj);
value = std::move(tensorrt_llm::runtime::TorchView::of(t));
return true;
}
return false;
}
// Convert tensorrt_llm::runtime::ITensor::SharedConstPtr -> PyObject(torch.Tensor)
static handle cast(tensorrt_llm::runtime::ITensor::SharedConstPtr const& src, return_value_policy /* policy */,
handle /* parent */)
{
if (src == nullptr)
{
return none().release();
}
return THPVariable_Wrap(tensorrt_llm::runtime::Torch::tensor(
reinterpret_cast<tensorrt_llm::runtime::ITensor::SharedPtr const&>(src)));
}
};
} // namespace detail
} // namespace PYBIND11_NAMESPACE

View File

@ -1,67 +0,0 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tllmExceptions.h"
namespace py = pybind11;
namespace tc = tensorrt_llm::common;
namespace tensorrt_llm::pybind::common
{
void initExceptionsBindings(py::module_& m)
{
// Bind the RequestErrorCode enum
py::enum_<tc::RequestErrorCode>(m, "RequestErrorCode")
.value("UNKNOWN_ERROR", tc::RequestErrorCode::kUNKNOWN_ERROR)
.value("NETWORK_ERROR", tc::RequestErrorCode::kNETWORK_ERROR)
.export_values();
// Create the RequestSpecificException Python exception class
static PyObject* request_specific_exc
= PyErr_NewException("tensorrt_llm.RequestSpecificException", nullptr, nullptr);
// Add attributes to the Python exception class
py::handle(request_specific_exc).attr("request_id") = py::none();
py::handle(request_specific_exc).attr("error_code") = py::none();
m.add_object("RequestSpecificException", py::handle(request_specific_exc));
// Register exception translator to convert C++ exceptions to Python
py::register_exception_translator(
[](std::exception_ptr p)
{
try
{
if (p)
std::rethrow_exception(p);
}
catch (const tc::RequestSpecificException& e)
{
// Create a Python exception with the request ID and error code information
py::object msg = py::str(e.what());
py::object inst = py::reinterpret_steal<py::object>(
PyObject_CallFunctionObjArgs(request_specific_exc, msg.ptr(), nullptr));
inst.attr("request_id") = py::cast(e.getRequestId());
inst.attr("error_code") = py::cast(e.getErrorCode());
PyErr_SetObject(request_specific_exc, inst.ptr());
}
});
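    // With this translator installed, a tc::RequestSpecificException thrown inside the bindings surfaces in Python
    // as tensorrt_llm.RequestSpecificException with request_id and error_code populated from the C++ exception.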
}
} // namespace tensorrt_llm::pybind::common

View File

@ -1,32 +0,0 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/common/tllmException.h"
#include <pybind11/pybind11.h>
namespace py = pybind11;
namespace tc = tensorrt_llm::common;
namespace tensorrt_llm::pybind::common
{
/// @brief Bind RequestSpecificException and related types to Python
/// @param m The pybind11 module to bind to
void initExceptionsBindings(py::module_& m);
} // namespace tensorrt_llm::pybind::common

View File

@ -1,280 +0,0 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "bindings.h"
#include "executor.h"
#include "executorConfig.h"
#include "request.h"
#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/executor/types.h"
#include <pybind11/cast.h>
#include <pybind11/chrono.h>
#include <pybind11/functional.h>
#include <pybind11/operators.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <optional>
namespace py = pybind11;
namespace tle = tensorrt_llm::executor;
using SizeType32 = tle::SizeType32;
namespace tensorrt_llm::pybind::executor
{
template <typename T>
void instantiateEventDiff(pybind11::module& m, std::string const& name)
{
py::class_<tle::KVCacheEventDiff<T>>(m, ("KVCacheEventDiff" + name).c_str())
.def_readonly("old_value", &tle::KVCacheEventDiff<T>::oldValue)
.def_readonly("new_value", &tle::KVCacheEventDiff<T>::newValue);
}
void initBindings(pybind11::module_& m)
{
m.attr("__version__") = tle::version();
py::enum_<tle::ModelType>(m, "ModelType")
.value("DECODER_ONLY", tle::ModelType::kDECODER_ONLY)
.value("ENCODER_ONLY", tle::ModelType::kENCODER_ONLY)
.value("ENCODER_DECODER", tle::ModelType::kENCODER_DECODER);
auto decodingModeGetstate = [](tle::DecodingMode const& self) { return py::make_tuple(self.getState()); };
auto decodingModeSetstate = [](py::tuple const& state)
{
if (state.size() != 1)
{
throw std::runtime_error("Invalid state!");
}
return tle::DecodingMode(state[0].cast<tle::DecodingMode::UnderlyingType>());
};
py::class_<tle::DecodingMode>(m, "DecodingMode")
.def("Auto", &tle::DecodingMode::Auto)
.def("TopK", &tle::DecodingMode::TopK)
.def("TopP", &tle::DecodingMode::TopP)
.def("TopKTopP", &tle::DecodingMode::TopKTopP)
.def("BeamSearch", &tle::DecodingMode::BeamSearch)
.def("Medusa", &tle::DecodingMode::Medusa)
.def("Lookahead", &tle::DecodingMode::Lookahead)
.def("ExplicitDraftTokens", &tle::DecodingMode::ExplicitDraftTokens)
.def("Eagle", &tle::DecodingMode::Eagle)
.def("isAuto", &tle::DecodingMode::isAuto)
.def("isTopK", &tle::DecodingMode::isTopK)
.def("isTopP", &tle::DecodingMode::isTopP)
.def("isTopKorTopP", &tle::DecodingMode::isTopKorTopP)
.def("isTopKandTopP", &tle::DecodingMode::isTopKandTopP)
.def("isBeamSearch", &tle::DecodingMode::isBeamSearch)
.def("isMedusa", &tle::DecodingMode::isMedusa)
.def("isLookahead", &tle::DecodingMode::isLookahead)
.def("isExplicitDraftTokens", &tle::DecodingMode::isExplicitDraftTokens)
.def("isEagle", &tle::DecodingMode::isEagle)
.def("useVariableBeamWidthSearch", &tle::DecodingMode::useVariableBeamWidthSearch)
.def_property_readonly("name", &tle::DecodingMode::getName)
.def(py::pickle(decodingModeGetstate, decodingModeSetstate));
py::enum_<tle::CapacitySchedulerPolicy>(m, "CapacitySchedulerPolicy")
.value("MAX_UTILIZATION", tle::CapacitySchedulerPolicy::kMAX_UTILIZATION)
.value("GUARANTEED_NO_EVICT", tle::CapacitySchedulerPolicy::kGUARANTEED_NO_EVICT)
.value("STATIC_BATCH", tle::CapacitySchedulerPolicy::kSTATIC_BATCH);
py::enum_<tle::ContextChunkingPolicy>(m, "ContextChunkingPolicy")
.value("EQUAL_PROGRESS", tle::ContextChunkingPolicy::kEQUAL_PROGRESS)
.value("FIRST_COME_FIRST_SERVED", tle::ContextChunkingPolicy::kFIRST_COME_FIRST_SERVED);
py::enum_<tle::CommunicationType>(m, "CommunicationType").value("MPI", tle::CommunicationType::kMPI);
py::enum_<tle::CommunicationMode>(m, "CommunicationMode")
.value("LEADER", tle::CommunicationMode::kLEADER)
.value("ORCHESTRATOR", tle::CommunicationMode::kORCHESTRATOR);
py::class_<tle::KvCacheStats>(m, "KvCacheStats")
.def(py::init<>())
.def_readwrite("max_num_blocks", &tle::KvCacheStats::maxNumBlocks)
.def_readwrite("free_num_blocks", &tle::KvCacheStats::freeNumBlocks)
.def_readwrite("used_num_blocks", &tle::KvCacheStats::usedNumBlocks)
.def_readwrite("tokens_per_block", &tle::KvCacheStats::tokensPerBlock)
.def_readwrite("alloc_total_blocks", &tle::KvCacheStats::allocTotalBlocks)
.def_readwrite("alloc_new_blocks", &tle::KvCacheStats::allocNewBlocks)
.def_readwrite("reused_blocks", &tle::KvCacheStats::reusedBlocks)
.def_readwrite("missed_blocks", &tle::KvCacheStats::missedBlocks)
.def_readwrite("cache_hit_rate", &tle::KvCacheStats::cacheHitRate);
py::class_<tle::StaticBatchingStats>(m, "StaticBatchingStats")
.def(py::init<>())
.def_readwrite("num_scheduled_requests", &tle::StaticBatchingStats::numScheduledRequests)
.def_readwrite("num_context_requests", &tle::StaticBatchingStats::numContextRequests)
.def_readwrite("num_ctx_tokens", &tle::StaticBatchingStats::numCtxTokens)
.def_readwrite("num_gen_tokens", &tle::StaticBatchingStats::numGenTokens)
.def_readwrite("empty_gen_slots", &tle::StaticBatchingStats::emptyGenSlots);
py::class_<tle::InflightBatchingStats>(m, "InflightBatchingStats")
.def(py::init<>())
.def_readwrite("num_scheduled_requests", &tle::InflightBatchingStats::numScheduledRequests)
.def_readwrite("num_context_requests", &tle::InflightBatchingStats::numContextRequests)
.def_readwrite("num_gen_requests", &tle::InflightBatchingStats::numGenRequests)
.def_readwrite("num_paused_requests", &tle::InflightBatchingStats::numPausedRequests)
.def_readwrite("num_ctx_tokens", &tle::InflightBatchingStats::numCtxTokens)
.def_readwrite("micro_batch_id", &tle::InflightBatchingStats::microBatchId)
.def_readwrite("avg_num_decoded_tokens_per_iter", &tle::InflightBatchingStats::avgNumDecodedTokensPerIter);
py::class_<tle::SpecDecodingStats>(m, "SpecDecodingStats")
.def(py::init<>())
.def_readwrite("num_draft_tokens", &tle::SpecDecodingStats::numDraftTokens)
.def_readwrite("num_accepted_tokens", &tle::SpecDecodingStats::numAcceptedTokens)
.def_readwrite("num_requests_with_draft_tokens", &tle::SpecDecodingStats::numRequestsWithDraftTokens)
.def_readwrite("acceptance_length", &tle::SpecDecodingStats::acceptanceLength)
.def_readwrite("iter_latency_ms", &tle::SpecDecodingStats::iterLatencyMS)
.def_readwrite("draft_overhead", &tle::SpecDecodingStats::draftOverhead);
py::class_<tle::IterationStats>(m, "IterationStats")
.def(py::init<>())
.def_readwrite("timestamp", &tle::IterationStats::timestamp)
.def_readwrite("iter", &tle::IterationStats::iter)
.def_readwrite("iter_latency_ms", &tle::IterationStats::iterLatencyMS)
.def_readwrite("new_active_requests_queue_latency_ms", &tle::IterationStats::newActiveRequestsQueueLatencyMS)
.def_readwrite("num_new_active_requests", &tle::IterationStats::numNewActiveRequests)
.def_readwrite("num_active_requests", &tle::IterationStats::numActiveRequests)
.def_readwrite("num_queued_requests", &tle::IterationStats::numQueuedRequests)
.def_readwrite("num_completed_requests", &tle::IterationStats::numCompletedRequests)
.def_readwrite("max_num_active_requests", &tle::IterationStats::maxNumActiveRequests)
.def_readwrite("gpu_mem_usage", &tle::IterationStats::gpuMemUsage)
.def_readwrite("cpu_mem_usage", &tle::IterationStats::cpuMemUsage)
.def_readwrite("pinned_mem_usage", &tle::IterationStats::pinnedMemUsage)
.def_readwrite("kv_cache_stats", &tle::IterationStats::kvCacheStats)
.def_readwrite("cross_kv_cache_stats", &tle::IterationStats::crossKvCacheStats)
.def_readwrite("static_batching_stats", &tle::IterationStats::staticBatchingStats)
.def_readwrite("inflight_batching_stats", &tle::IterationStats::inflightBatchingStats)
.def_readwrite("specdec_stats", &tle::IterationStats::specDecodingStats)
.def("to_json_str",
[](tle::IterationStats const& iterationStats)
{ return tle::JsonSerialization::toJsonStr(iterationStats); });
py::class_<tle::DebugTensorsPerIteration>(m, "DebugTensorsPerIteration")
.def(py::init<>())
.def_readwrite("iter", &tle::DebugTensorsPerIteration::iter)
.def_readwrite("debug_tensors", &tle::DebugTensorsPerIteration::debugTensors);
py::enum_<tle::RequestStage>(m, "RequestStage")
.value("QUEUED", tle::RequestStage::kQUEUED)
.value("ENCODER_IN_PROGRESS", tle::RequestStage::kENCODER_IN_PROGRESS)
.value("CONTEXT_IN_PROGRESS", tle::RequestStage::kCONTEXT_IN_PROGRESS)
.value("GENERATION_IN_PROGRESS", tle::RequestStage::kGENERATION_IN_PROGRESS)
.value("GENERATION_COMPLETE", tle::RequestStage::kGENERATION_COMPLETE);
py::class_<tle::DisServingRequestStats>(m, "DisServingRequestStats")
.def(py::init<>())
.def_readwrite("kv_cache_transfer_ms", &tle::DisServingRequestStats::kvCacheTransferMS)
.def_readwrite("kv_cache_size", &tle::DisServingRequestStats::kvCacheSize);
py::class_<tle::RequestStats>(m, "RequestStats")
.def(py::init<>())
.def_readwrite("id", &tle::RequestStats::id)
.def_readwrite("stage", &tle::RequestStats::stage)
.def_readwrite("context_prefill_position", &tle::RequestStats::contextPrefillPosition)
.def_readwrite("num_generated_tokens", &tle::RequestStats::numGeneratedTokens)
.def_readwrite("avg_num_decoded_tokens_per_iter", &tle::RequestStats::avgNumDecodedTokensPerIter)
.def_readwrite("scheduled", &tle::RequestStats::scheduled)
.def_readwrite("paused", &tle::RequestStats::paused)
.def_readwrite("dis_serving_stats", &tle::RequestStats::disServingStats)
.def_readwrite("alloc_total_blocks_per_request", &tle::RequestStats::allocTotalBlocksPerRequest)
.def_readwrite("alloc_new_blocks_per_request", &tle::RequestStats::allocNewBlocksPerRequest)
.def_readwrite("reused_blocks_per_request", &tle::RequestStats::reusedBlocksPerRequest)
.def_readwrite("missed_blocks_per_request", &tle::RequestStats::missedBlocksPerRequest)
.def_readwrite("kv_cache_hit_rate_per_request", &tle::RequestStats::kvCacheHitRatePerRequest)
.def("to_json_str",
[](tle::RequestStats const& iterationStats) { return tle::JsonSerialization::toJsonStr(iterationStats); });
py::class_<tle::RequestStatsPerIteration>(m, "RequestStatsPerIteration")
.def(py::init<>())
.def_readwrite("iter", &tle::RequestStatsPerIteration::iter)
.def_readwrite("request_stats", &tle::RequestStatsPerIteration::requestStats)
.def("to_json_str",
[](tle::RequestStatsPerIteration const& iterationStats)
{ return tle::JsonSerialization::toJsonStr(iterationStats); });
py::module_ executor_kv_cache = m.def_submodule("kv_cache", "Executor KV Cache Manager");
py::class_<tle::KVCacheCreatedData>(executor_kv_cache, "KVCacheCreatedData")
.def_readonly("num_blocks_per_cache_level", &tle::KVCacheCreatedData::numBlocksPerCacheLevel);
py::class_<tensorrt_llm::runtime::UniqueToken>(executor_kv_cache, "UniqueToken")
.def_readonly("token_id", &tensorrt_llm::runtime::UniqueToken::tokenId)
.def_readonly("token_extra_id", &tensorrt_llm::runtime::UniqueToken::tokenExtraId);
py::class_<tle::KVCacheStoredBlockData>(executor_kv_cache, "KVCacheStoredBlockData")
.def_readonly("block_hash", &tle::KVCacheStoredBlockData::blockHash)
.def_readonly("tokens", &tle::KVCacheStoredBlockData::tokens)
.def_readonly("lora_id", &tle::KVCacheStoredBlockData::loraId)
.def_readonly("cache_level", &tle::KVCacheStoredBlockData::cacheLevel)
.def_readonly("priority", &tle::KVCacheStoredBlockData::priority)
.def_property_readonly("mm_keys",
[](tle::KVCacheStoredBlockData const& self)
{
// Convert std::vector<MmKey> to Python list of tuples (bytes, int)
// MmKey = std::pair<std::array<uint8_t, 32>, SizeType32>
py::list result;
for (auto const& mmKey : self.mmKeys)
{
auto const& hashArray = mmKey.first;
auto offset = mmKey.second;
py::bytes hashBytes(reinterpret_cast<char const*>(hashArray.data()), hashArray.size());
result.append(py::make_tuple(hashBytes, offset));
}
return result;
});
py::class_<tle::KVCacheStoredData>(executor_kv_cache, "KVCacheStoredData")
.def_readonly("parent_hash", &tle::KVCacheStoredData::parentHash)
.def_readonly("blocks", &tle::KVCacheStoredData::blocks);
py::class_<tle::KVCacheRemovedData>(executor_kv_cache, "KVCacheRemovedData")
.def_readonly("block_hashes", &tle::KVCacheRemovedData::blockHashes);
instantiateEventDiff<SizeType32>(executor_kv_cache, "Int");
py::class_<tle::KVCacheUpdatedData>(executor_kv_cache, "KVCacheUpdatedData")
.def_readonly("block_hash", &tle::KVCacheUpdatedData::blockHash)
.def_readonly("cache_level", &tle::KVCacheUpdatedData::cacheLevel)
.def_readonly("priority", &tle::KVCacheUpdatedData::priority);
py::class_<tle::KVCacheEvent>(executor_kv_cache, "KVCacheEvent")
.def_readonly("event_id", &tle::KVCacheEvent::eventId)
.def_readonly("data", &tle::KVCacheEvent::data)
.def_readonly("window_size", &tle::KVCacheEvent::windowSize)
.def_readonly("attention_dp_rank", &tle::KVCacheEvent::attentionDpRank);
py::class_<tle::KVCacheEventManager, std::shared_ptr<tle::KVCacheEventManager>>(
executor_kv_cache, "KVCacheEventManager")
.def(
"get_latest_events",
[](tle::KVCacheEventManager& self, std::optional<double> timeout_ms = std::nullopt)
{
if (timeout_ms)
{
return self.getLatestEvents(std::chrono::milliseconds(static_cast<int64_t>(*timeout_ms)));
}
return self.getLatestEvents(std::nullopt);
},
py::arg("timeout_ms") = std::nullopt);
tensorrt_llm::pybind::executor::initRequestBindings(m);
tensorrt_llm::pybind::executor::initConfigBindings(m);
tensorrt_llm::pybind::executor::Executor::initBindings(m);
}
} // namespace tensorrt_llm::pybind::executor

View File

@ -1,29 +0,0 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/pybind/common/customCasters.h"
#include <pybind11/pybind11.h>
namespace tensorrt_llm::pybind::executor
{
// Register bindings for the executor API.
void initBindings(pybind11::module_& m);
} // namespace tensorrt_llm::pybind::executor

View File

@ -1,191 +0,0 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "executor.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/executor/tensor.h"
#include <pybind11/chrono.h>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
namespace py = pybind11;
namespace tle = tensorrt_llm::executor;
namespace
{
tle::Tensor numpyToTensor(py::array const& array)
{
auto npDtype = array.dtype();
tle::DataType dtype;
if (npDtype.is(py::dtype("float16")))
{
dtype = tle::DataType::kFP16;
}
else if (npDtype.is(py::dtype("float32")))
{
dtype = tle::DataType::kFP32;
}
else if (npDtype.is(py::dtype("int8")))
{
dtype = tle::DataType::kINT8;
}
else if (npDtype.is(py::dtype("int32")))
{
dtype = tle::DataType::kINT32;
}
else if (npDtype.is(py::dtype("int64")))
{
dtype = tle::DataType::kINT64;
}
else if (npDtype.attr("kind").cast<std::string>() == "V" && npDtype.attr("itemsize").cast<int>() == 1
&& npDtype.attr("metadata")["dtype"].cast<std::string>() == "float8")
{
dtype = tle::DataType::kFP8;
}
else if (npDtype.attr("kind").cast<std::string>() == "V" && npDtype.attr("itemsize").cast<int>() == 2
&& npDtype.attr("metadata")["dtype"].cast<std::string>() == "bfloat16")
{
dtype = tle::DataType::kBF16;
}
else
{
TLLM_THROW("Unsupported numpy dtype: " + npDtype.attr("name").cast<std::string>());
}
tle::Shape shape(array.shape(), array.ndim());
return tle::Tensor::of(dtype, const_cast<void*>(array.data()), shape);
}
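// Note on numpyToTensor above: the float8 and bfloat16 branches accept 1- and 2-byte void ("V") dtypes whose
// metadata names the logical type, since stock numpy has no native representation for those formats.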
} // namespace
namespace tensorrt_llm::pybind::executor
{
Executor::Executor(
std::filesystem::path const& modelPath, tle::ModelType modelType, tle::ExecutorConfig const& executorConfig)
{
mExecutor = std::make_unique<tle::Executor>(modelPath, modelType, executorConfig);
}
Executor::Executor(std::filesystem::path const& encoderModelPath, std::filesystem::path const& decoderModelPath,
tle::ModelType modelType, tle::ExecutorConfig const& executorConfig)
{
mExecutor = std::make_unique<tle::Executor>(encoderModelPath, decoderModelPath, modelType, executorConfig);
}
Executor::Executor(pybind11::buffer engineBuffer, std::string const& jsonConfigStr, tle::ModelType modelType,
tle::ExecutorConfig const& executorConfig, std::optional<pybind11::dict> managedWeights)
{
py::buffer_info info = engineBuffer.request();
uint8_t const* data = reinterpret_cast<uint8_t const*>(info.ptr);
size_t size = info.size;
std::optional<std::map<std::string, tle::Tensor>> managedWeightsMap = std::nullopt;
if (managedWeights.has_value() && !managedWeights.value().empty())
{
managedWeightsMap = std::map<std::string, tle::Tensor>();
for (auto const& item : managedWeights.value())
{
std::string name = item.first.cast<std::string>();
py::array array = item.second.cast<py::array>();
managedWeightsMap->emplace(name, numpyToTensor(array));
}
}
mExecutor = std::make_unique<tle::Executor>(
tle::BufferView(data, size), jsonConfigStr, modelType, executorConfig, managedWeightsMap);
}
Executor::Executor(std::string const& encoderEngineBuffer, std::string const& encoderJsonConfigStr,
std::string const& decoderEngineBuffer, std::string const& decoderJsonConfigStr, tle::ModelType modelType,
tle::ExecutorConfig const& executorConfig)
{
uint8_t const* encoderData = reinterpret_cast<uint8_t const*>(encoderEngineBuffer.data());
size_t encoderSize = encoderEngineBuffer.size();
uint8_t const* decoderData = reinterpret_cast<uint8_t const*>(decoderEngineBuffer.data());
size_t decoderSize = decoderEngineBuffer.size();
mExecutor = std::make_unique<tle::Executor>(tle::BufferView(encoderData, encoderSize), encoderJsonConfigStr,
tle::BufferView(decoderData, decoderSize), decoderJsonConfigStr, modelType, executorConfig);
}
py::object Executor::enter()
{
TLLM_CHECK(static_cast<bool>(mExecutor));
return py::cast(this);
}
void Executor::exit(
[[maybe_unused]] py::handle type, [[maybe_unused]] py::handle value, [[maybe_unused]] py::handle traceback)
{
shutdown();
mExecutor = nullptr;
}
void Executor::shutdown()
{
    // NOTE: we must release the GIL here. The Executor has spawned a thread for its execution loop, and that thread
    // must be able to make forward progress for the shutdown process to succeed. It takes the GIL during its
    // callbacks, so we release it now. Nothing that touches Python objects may happen after this point.
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
py::gil_scoped_release release;
mExecutor->shutdown();
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
void Executor::initBindings(py::module_& m)
{
py::class_<Executor>(m, "Executor")
.def(py::init<std::filesystem::path const&, tle::ModelType, tle::ExecutorConfig const&>(),
py::arg("model_path"), py::arg("model_type"), py::arg("executor_config"))
.def(py::init<std::filesystem::path const&, std::filesystem::path const&, tle::ModelType,
tle::ExecutorConfig const&>(),
py::arg("encoder_model_path"), py::arg("decoder_model_path"), py::arg("model_type"),
py::arg("executor_config"))
.def(py::init<py::buffer, std::string const&, tle::ModelType, tle::ExecutorConfig const&, py::dict>(),
py::arg("engine_buffer"), py::arg("json_config_str"), py::arg("model_type"), py::arg("executor_config"),
py::arg("managed_weights") = py::dict())
.def(py::init<std::string const&, std::string const&, std::string const&, std::string const&, tle::ModelType,
tle::ExecutorConfig const&>(),
py::arg("encoder_engine_buffer"), py::arg("encoder_json_config_str"), py::arg("decoder_engine_buffer"),
py::arg("decoder_json_config_str"), py::arg("model_type"), py::arg("executor_config"))
.def("shutdown", &Executor::shutdown)
.def("__enter__", &Executor::enter)
.def("__exit__", &Executor::exit)
.def("enqueue_request", &Executor::enqueueRequest, py::arg("request"))
.def("enqueue_requests", &Executor::enqueueRequests, py::arg("requests"))
.def("await_responses",
py::overload_cast<std::optional<std::chrono::milliseconds> const&>(&Executor::awaitResponses),
py::arg("timeout") = py::none())
.def("await_responses",
py::overload_cast<tle::IdType const&, std::optional<std::chrono::milliseconds> const&>(
&Executor::awaitResponses),
py::arg("id"), py::arg("timeout") = py::none())
.def("await_responses",
py::overload_cast<std::vector<tle::IdType> const&, std::optional<std::chrono::milliseconds> const&>(
&Executor::awaitResponses),
py::arg("ids"), py::arg("timeout") = py::none())
.def("get_num_responses_ready", &Executor::getNumResponsesReady, py::arg("id") = py::none())
.def("cancel_request", &Executor::cancelRequest, py::arg("id") = py::none())
.def("get_latest_iteration_stats", &Executor::getLatestIterationStats)
.def("get_latest_request_stats", &Executor::getLatestRequestStats)
.def("get_latest_debug_tensors", &Executor::getLatestDebugTensors)
.def("can_enqueue_requests", &Executor::canEnqueueRequests)
.def("get_kv_cache_event_manager", &Executor::getKVCacheEventManager);
}
} // namespace tensorrt_llm::pybind::executor

View File

@ -1,129 +0,0 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/executor/types.h"
#include "tensorrt_llm/pybind/common/customCasters.h"
#include <pybind11/pybind11.h>
namespace tle = tensorrt_llm::executor;
namespace tensorrt_llm::pybind::executor
{
class Executor
{
public:
Executor(
std::filesystem::path const& modelPath, tle::ModelType modelType, tle::ExecutorConfig const& executorConfig);
Executor(std::filesystem::path const& encoderModelPath, std::filesystem::path const& decoderModelPath,
tle::ModelType modelType, tle::ExecutorConfig const& executorConfig);
Executor(pybind11::buffer engineBuffer, std::string const& jsonConfigStr, tle::ModelType modelType,
tle::ExecutorConfig const& executorConfig, std::optional<pybind11::dict> managedWeights);
Executor(std::string const& encoderEngineBuffer, std::string const& encoderJsonConfigStr,
std::string const& decoderEngineBuffer, std::string const& decoderJsonConfigStr, tle::ModelType modelType,
tle::ExecutorConfig const& executorConfig);
pybind11::object enter();
void exit([[maybe_unused]] pybind11::handle type, [[maybe_unused]] pybind11::handle value,
[[maybe_unused]] pybind11::handle traceback);
void shutdown();
[[nodiscard]] tle::IdType enqueueRequest(tle::Request const& request)
{
return mExecutor->enqueueRequest(request);
}
[[nodiscard]] std::vector<tle::IdType> enqueueRequests(std::vector<tle::Request> const& requests)
{
return mExecutor->enqueueRequests(requests);
}
[[nodiscard]] std::vector<tle::Response> awaitResponses(
std::optional<std::chrono::milliseconds> const& timeout = std::nullopt)
{
        // awaitResponses blocks until a response is received. Release the GIL so that it can run in a background
        // thread.
pybind11::gil_scoped_release release;
return mExecutor->awaitResponses(timeout);
}
[[nodiscard]] std::vector<tle::Response> awaitResponses(
tle::IdType const& requestId, std::optional<std::chrono::milliseconds> const& timeout = std::nullopt)
{
        // awaitResponses blocks until a response is received. Release the GIL so that it can run in a background
        // thread.
pybind11::gil_scoped_release release;
return mExecutor->awaitResponses(requestId, timeout);
}
[[nodiscard]] std::vector<std::vector<tle::Response>> awaitResponses(std::vector<tle::IdType> const& requestIds,
std::optional<std::chrono::milliseconds> const& timeout = std::nullopt)
{
        // awaitResponses blocks until a response is received. Release the GIL so that it can run in a background
        // thread.
pybind11::gil_scoped_release release;
return mExecutor->awaitResponses(requestIds, timeout);
}
[[nodiscard]] tle::SizeType32 getNumResponsesReady(std::optional<tle::IdType> const& requestId = std::nullopt) const
{
return mExecutor->getNumResponsesReady(requestId);
}
void cancelRequest(tle::IdType requestId)
{
mExecutor->cancelRequest(requestId);
}
std::deque<tle::IterationStats> getLatestIterationStats()
{
return mExecutor->getLatestIterationStats();
}
std::deque<tle::RequestStatsPerIteration> getLatestRequestStats()
{
return mExecutor->getLatestRequestStats();
}
std::deque<tle::DebugTensorsPerIteration> getLatestDebugTensors()
{
return mExecutor->getLatestDebugTensors();
}
[[nodiscard]] bool canEnqueueRequests() const
{
return mExecutor->canEnqueueRequests();
}
[[nodiscard]] std::optional<std::shared_ptr<tle::KVCacheEventManager>> getKVCacheEventManager() const
{
return mExecutor->getKVCacheEventManager();
}
static void initBindings(pybind11::module_& m);
private:
std::unique_ptr<tle::Executor> mExecutor;
};
} // namespace tensorrt_llm::pybind::executor

View File

@ -1,642 +0,0 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "executorConfig.h"
#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/executor/types.h"
#include "tensorrt_llm/runtime/cudaStream.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
#include <optional>
#include <pybind11/cast.h>
#include <pybind11/functional.h>
#include <pybind11/operators.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <torch/torch.h>
#include <vector>
namespace py = pybind11;
namespace tle = tensorrt_llm::executor;
using SizeType32 = tle::SizeType32;
using RuntimeDefaults = tensorrt_llm::runtime::RuntimeDefaults;
namespace tensorrt_llm::pybind::executor
{
void initConfigBindings(pybind11::module_& m)
{
py::enum_<tle::BatchingType>(m, "BatchingType")
.value("STATIC", tle::BatchingType::kSTATIC)
.value("INFLIGHT", tle::BatchingType::kINFLIGHT);
auto dynamicBatchConfigGetstate = [](tle::DynamicBatchConfig const& self)
{
return py::make_tuple(self.getEnableBatchSizeTuning(), self.getEnableMaxNumTokensTuning(),
self.getDynamicBatchMovingAverageWindow(), self.getBatchSizeTable());
};
auto dynamicBatchConfigSetstate = [](py::tuple const& state)
{
if (state.size() != 4)
{
throw std::runtime_error("Invalid state!");
}
return tle::DynamicBatchConfig(state[0].cast<bool>(), state[1].cast<bool>(), state[2].cast<SizeType32>(),
state[3].cast<std::vector<std::pair<SizeType32, SizeType32>>>());
};
py::class_<tle::DynamicBatchConfig>(m, "DynamicBatchConfig")
.def(py::init<bool, bool, SizeType32>(), py::arg("enable_batch_size_tuning"),
py::arg("enable_max_num_tokens_tuning"), py::arg("dynamic_batch_moving_average_window"))
.def_property_readonly("enable_batch_size_tuning", &tle::DynamicBatchConfig::getEnableBatchSizeTuning)
.def_property_readonly("enable_max_num_tokens_tuning", &tle::DynamicBatchConfig::getEnableMaxNumTokensTuning)
.def_property_readonly(
"dynamic_batch_moving_average_window", &tle::DynamicBatchConfig::getDynamicBatchMovingAverageWindow)
.def(py::pickle(dynamicBatchConfigGetstate, dynamicBatchConfigSetstate));
auto schedulerConfigSetstate = [](py::tuple const& state)
{
if (state.size() != 3)
{
throw std::runtime_error("Invalid state!");
}
return tle::SchedulerConfig(state[0].cast<tle::CapacitySchedulerPolicy>(),
state[1].cast<std::optional<tle::ContextChunkingPolicy>>(),
state[2].cast<std::optional<tle::DynamicBatchConfig>>());
};
auto schedulerConfigGetstate = [](tle::SchedulerConfig const& self)
{
return py::make_tuple(
self.getCapacitySchedulerPolicy(), self.getContextChunkingPolicy(), self.getDynamicBatchConfig());
};
py::class_<tle::SchedulerConfig>(m, "SchedulerConfig")
.def(py::init<tle::CapacitySchedulerPolicy, std::optional<tle::ContextChunkingPolicy>,
std::optional<tle::DynamicBatchConfig>>(),
py::arg_v("capacity_scheduler_policy", tle::CapacitySchedulerPolicy::kGUARANTEED_NO_EVICT,
"CapacitySchedulerPolicy.GUARANTEED_NO_EVICT"),
py::arg("context_chunking_policy") = py::none(), py::arg("dynamic_batch_config") = py::none())
.def_property_readonly("capacity_scheduler_policy", &tle::SchedulerConfig::getCapacitySchedulerPolicy)
.def_property_readonly("context_chunking_policy", &tle::SchedulerConfig::getContextChunkingPolicy)
.def_property_readonly("dynamic_batch_config", &tle::SchedulerConfig::getDynamicBatchConfig)
.def(py::pickle(schedulerConfigGetstate, schedulerConfigSetstate));
py::class_<RuntimeDefaults>(m, "RuntimeDefaults")
.def(py::init<std::optional<std::vector<SizeType32>>, std::optional<SizeType32>>(),
py::arg("max_attention_window") = py::none(), py::arg("sink_token_length") = py::none())
.def_readonly("max_attention_window", &RuntimeDefaults::maxAttentionWindowVec)
.def_readonly("sink_token_length", &RuntimeDefaults::sinkTokenLength);
auto kvCacheConfigGetstate = [](tle::KvCacheConfig const& self)
{
return py::make_tuple(self.getEnableBlockReuse(), self.getMaxTokens(), self.getMaxAttentionWindowVec(),
self.getSinkTokenLength(), self.getFreeGpuMemoryFraction(), self.getHostCacheSize(),
self.getOnboardBlocks(), self.getCrossKvCacheFraction(), self.getSecondaryOffloadMinPriority(),
self.getEventBufferMaxSize(), self.getEnablePartialReuse(), self.getCopyOnPartialReuse(), self.getUseUvm(),
self.getAttentionDpEventsGatherPeriodMs(), self.getMaxGpuTotalBytes());
};
auto kvCacheConfigSetstate = [](py::tuple const& state)
{
if (state.size() != 15)
{
throw std::runtime_error("Invalid state!");
}
return tle::KvCacheConfig(state[0].cast<bool>(), state[1].cast<std::optional<SizeType32>>(),
state[2].cast<std::optional<std::vector<SizeType32>>>(), state[3].cast<std::optional<SizeType32>>(),
state[4].cast<std::optional<float>>(), state[5].cast<std::optional<size_t>>(), state[6].cast<bool>(),
state[7].cast<std::optional<float>>(), state[8].cast<std::optional<tle::RetentionPriority>>(),
state[9].cast<size_t>(), state[10].cast<bool>(), state[11].cast<bool>(), state[12].cast<bool>(),
state[13].cast<SizeType32>(), std::nullopt, state[14].cast<uint64_t>());
};
py::class_<tle::KvCacheConfig>(m, "KvCacheConfig")
.def(py::init<bool, std::optional<SizeType32> const&, std::optional<std::vector<SizeType32>> const&,
std::optional<SizeType32> const&, std::optional<float> const&, std::optional<size_t> const&, bool,
std::optional<float> const&, std::optional<tle::RetentionPriority>, size_t const&, bool, bool, bool,
SizeType32, std::optional<RuntimeDefaults> const&, uint64_t const&>(),
py::arg("enable_block_reuse") = true, py::arg("max_tokens") = py::none(),
py::arg("max_attention_window") = py::none(), py::arg("sink_token_length") = py::none(),
py::arg("free_gpu_memory_fraction") = py::none(), py::arg("host_cache_size") = py::none(),
py::arg("onboard_blocks") = true, py::arg("cross_kv_cache_fraction") = py::none(),
py::arg("secondary_offload_min_priority") = py::none(), py::arg("event_buffer_max_size") = 0, py::kw_only(),
py::arg("enable_partial_reuse") = true, py::arg("copy_on_partial_reuse") = true, py::arg("use_uvm") = false,
py::arg("attention_dp_events_gather_period_ms") = 5, py::arg("runtime_defaults") = py::none(),
py::arg("max_gpu_total_bytes") = 0)
.def_property(
"enable_block_reuse", &tle::KvCacheConfig::getEnableBlockReuse, &tle::KvCacheConfig::setEnableBlockReuse)
.def_property("max_tokens", &tle::KvCacheConfig::getMaxTokens, &tle::KvCacheConfig::setMaxTokens)
.def_property("max_attention_window", &tle::KvCacheConfig::getMaxAttentionWindowVec,
&tle::KvCacheConfig::setMaxAttentionWindowVec)
.def_property(
"sink_token_length", &tle::KvCacheConfig::getSinkTokenLength, &tle::KvCacheConfig::setSinkTokenLength)
.def_property("free_gpu_memory_fraction", &tle::KvCacheConfig::getFreeGpuMemoryFraction,
&tle::KvCacheConfig::setFreeGpuMemoryFraction)
.def_property(
"max_gpu_total_bytes", &tle::KvCacheConfig::getMaxGpuTotalBytes, &tle::KvCacheConfig::setMaxGpuTotalBytes)
.def_property("host_cache_size", &tle::KvCacheConfig::getHostCacheSize, &tle::KvCacheConfig::setHostCacheSize)
.def_property("onboard_blocks", &tle::KvCacheConfig::getOnboardBlocks, &tle::KvCacheConfig::setOnboardBlocks)
.def_property("cross_kv_cache_fraction", &tle::KvCacheConfig::getCrossKvCacheFraction,
&tle::KvCacheConfig::setCrossKvCacheFraction)
.def_property("secondary_offload_min_priority", &tle::KvCacheConfig::getSecondaryOffloadMinPriority,
&tle::KvCacheConfig::setSecondaryOffloadMinPriority)
.def_property("event_buffer_max_size", &tle::KvCacheConfig::getEventBufferMaxSize,
&tle::KvCacheConfig::setEventBufferMaxSize)
.def_property("enable_partial_reuse", &tle::KvCacheConfig::getEnablePartialReuse,
&tle::KvCacheConfig::setEnablePartialReuse)
.def_property("copy_on_partial_reuse", &tle::KvCacheConfig::getCopyOnPartialReuse,
&tle::KvCacheConfig::setCopyOnPartialReuse)
.def_property("use_uvm", &tle::KvCacheConfig::getUseUvm, &tle::KvCacheConfig::setUseUvm)
.def_property("attention_dp_events_gather_period_ms", &tle::KvCacheConfig::getAttentionDpEventsGatherPeriodMs,
&tle::KvCacheConfig::setAttentionDpEventsGatherPeriodMs)
.def("fill_empty_fields_from_runtime_defaults", &tle::KvCacheConfig::fillEmptyFieldsFromRuntimeDefaults)
.def(py::pickle(kvCacheConfigGetstate, kvCacheConfigSetstate));
py::class_<tle::OrchestratorConfig>(m, "OrchestratorConfig")
.def(py::init<bool, std::string, std::shared_ptr<mpi::MpiComm>, bool>(), py::arg("is_orchestrator") = true,
py::arg("worker_executable_path") = "", py::arg("orch_leader_comm") = nullptr,
py::arg("spawn_processes") = true)
.def_property(
"is_orchestrator", &tle::OrchestratorConfig::getIsOrchestrator, &tle::OrchestratorConfig::setIsOrchestrator)
.def_property("worker_executable_path", &tle::OrchestratorConfig::getWorkerExecutablePath,
&tle::OrchestratorConfig::setWorkerExecutablePath)
.def_property("orch_leader_comm", &tle::OrchestratorConfig::getOrchLeaderComm,
&tle::OrchestratorConfig::setOrchLeaderComm)
.def_property("spawn_processes", &tle::OrchestratorConfig::getSpawnProcesses,
&tle::OrchestratorConfig::setSpawnProcesses);
auto parallelConfigGetstate = [](tle::ParallelConfig const& self)
{
return py::make_tuple(self.getCommunicationType(), self.getCommunicationMode(), self.getDeviceIds(),
self.getParticipantIds(), self.getOrchestratorConfig(), self.getNumNodes());
};
auto parallelConfigSetstate = [](py::tuple const& state)
{
if (state.size() != 6)
{
throw std::runtime_error("Invalid state!");
}
return tle::ParallelConfig(state[0].cast<tle::CommunicationType>(), state[1].cast<tle::CommunicationMode>(),
state[2].cast<std::optional<std::vector<SizeType32>>>(),
state[3].cast<std::optional<std::vector<SizeType32>>>(),
state[4].cast<std::optional<tle::OrchestratorConfig>>(), state[5].cast<std::optional<SizeType32>>());
};
py::class_<tle::ParallelConfig>(m, "ParallelConfig")
.def(py::init<tle::CommunicationType, tle::CommunicationMode, std::optional<std::vector<SizeType32>> const&,
std::optional<std::vector<SizeType32>> const&, std::optional<tle::OrchestratorConfig> const&,
std::optional<SizeType32> const&>(),
py::arg_v("communication_type", tle::CommunicationType::kMPI, "CommunicationType.MPI"),
py::arg_v("communication_mode", tle::CommunicationMode::kLEADER, "CommunicationMode.LEADER"),
py::arg("device_ids") = py::none(), py::arg("participant_ids") = py::none(),
py::arg("orchestrator_config") = py::none(), py::arg("num_nodes") = py::none())
.def_property("communication_type", &tle::ParallelConfig::getCommunicationType,
&tle::ParallelConfig::setCommunicationType)
.def_property("communication_mode", &tle::ParallelConfig::getCommunicationMode,
&tle::ParallelConfig::setCommunicationMode)
.def_property("device_ids", &tle::ParallelConfig::getDeviceIds, &tle::ParallelConfig::setDeviceIds)
.def_property(
"participant_ids", &tle::ParallelConfig::getParticipantIds, &tle::ParallelConfig::setParticipantIds)
.def_property("orchestrator_config", &tle::ParallelConfig::getOrchestratorConfig,
&tle::ParallelConfig::setOrchestratorConfig)
.def_property("num_nodes", &tle::ParallelConfig::getNumNodes, &tle::ParallelConfig::setNumNodes)
.def(py::pickle(parallelConfigGetstate, parallelConfigSetstate));
auto peftCacheConfigSetstate = [](py::tuple const& state)
{
if (state.size() != 11)
{
throw std::runtime_error("Invalid state!");
}
return tle::PeftCacheConfig(state[0].cast<SizeType32>(), state[1].cast<SizeType32>(),
state[2].cast<SizeType32>(), state[3].cast<SizeType32>(), state[4].cast<SizeType32>(),
state[5].cast<SizeType32>(), state[6].cast<SizeType32>(), state[7].cast<SizeType32>(),
state[8].cast<SizeType32>(), state[9].cast<std::optional<float>>(),
state[10].cast<std::optional<size_t>>());
};
auto peftCacheConfigGetstate = [](tle::PeftCacheConfig const& self)
{
return py::make_tuple(self.getNumHostModuleLayer(), self.getNumDeviceModuleLayer(),
self.getOptimalAdapterSize(), self.getMaxAdapterSize(), self.getNumPutWorkers(), self.getNumEnsureWorkers(),
self.getNumCopyStreams(), self.getMaxPagesPerBlockHost(), self.getMaxPagesPerBlockDevice(),
self.getDeviceCachePercent(), self.getHostCacheSize());
};
py::class_<tle::PeftCacheConfig>(m, "PeftCacheConfig")
.def(py::init<SizeType32, SizeType32, SizeType32, SizeType32, SizeType32, SizeType32, SizeType32, SizeType32,
SizeType32, std::optional<float> const&, std::optional<size_t> const&,
std::optional<std::string> const&>(),
py::arg("num_host_module_layer") = 0, py::arg("num_device_module_layer") = 0,
py::arg("optimal_adapter_size") = 8, py::arg("max_adapter_size") = 64, py::arg("num_put_workers") = 1,
py::arg("num_ensure_workers") = 1, py::arg("num_copy_streams") = 1,
py::arg("max_pages_per_block_host") = 24, py::arg("max_pages_per_block_device") = 8,
py::arg("device_cache_percent") = py::none(), py::arg("host_cache_size") = py::none(),
py::arg("lora_prefetch_dir") = py::none())
.def_property_readonly("num_host_module_layer", &tle::PeftCacheConfig::getNumHostModuleLayer)
.def_property_readonly("num_device_module_layer", &tle::PeftCacheConfig::getNumDeviceModuleLayer)
.def_property_readonly("optimal_adapter_size", &tle::PeftCacheConfig::getOptimalAdapterSize)
.def_property_readonly("max_adapter_size", &tle::PeftCacheConfig::getMaxAdapterSize)
.def_property_readonly("num_put_workers", &tle::PeftCacheConfig::getNumPutWorkers)
.def_property_readonly("num_ensure_workers", &tle::PeftCacheConfig::getNumEnsureWorkers)
.def_property_readonly("num_copy_streams", &tle::PeftCacheConfig::getNumCopyStreams)
.def_property_readonly("max_pages_per_block_host", &tle::PeftCacheConfig::getMaxPagesPerBlockHost)
.def_property_readonly("max_pages_per_block_device", &tle::PeftCacheConfig::getMaxPagesPerBlockDevice)
.def_property_readonly("device_cache_percent", &tle::PeftCacheConfig::getDeviceCachePercent)
.def_property_readonly("host_cache_size", &tle::PeftCacheConfig::getHostCacheSize)
.def_property_readonly("lora_prefetch_dir", &tle::PeftCacheConfig::getLoraPrefetchDir)
.def(py::pickle(peftCacheConfigGetstate, peftCacheConfigSetstate));
auto decodingConfigGetstate = [](tle::DecodingConfig const& self)
{
return py::make_tuple(
self.getDecodingMode(), self.getLookaheadDecodingConfig(), self.getMedusaChoices(), self.getEagleConfig());
};
auto decodingConfigSetstate = [](py::tuple const& state)
{
if (state.size() != 4)
{
throw std::runtime_error("Invalid state!");
}
return tle::DecodingConfig(state[0].cast<std::optional<tle::DecodingMode>>(), // DecodingMode
state[1].cast<std::optional<tle::LookaheadDecodingConfig>>(), // LookaheadDecodingConfig
state[2].cast<std::optional<tle::MedusaChoices>>(), // MedusaChoices
state[3].cast<std::optional<tle::EagleConfig>>() // EagleConfig
);
};
py::class_<tle::DecodingConfig>(m, "DecodingConfig")
.def(py::init<std::optional<tle::DecodingMode>, std::optional<tle::LookaheadDecodingConfig>,
std::optional<tle::MedusaChoices>, std::optional<tle::EagleConfig>>(),
py::arg("decoding_mode") = py::none(), py::arg("lookahead_decoding_config") = py::none(),
py::arg("medusa_choices") = py::none(), py::arg("eagle_config") = py::none())
.def_property("decoding_mode", &tle::DecodingConfig::getDecodingMode, &tle::DecodingConfig::setDecodingMode)
.def_property("lookahead_decoding_config", &tle::DecodingConfig::getLookaheadDecodingConfig,
&tle::DecodingConfig::setLookaheadDecodingConfig)
.def_property("medusa_choices", &tle::DecodingConfig::getMedusaChoices, &tle::DecodingConfig::setMedusaChoices)
.def_property("eagle_config", &tle::DecodingConfig::getEagleConfig, &tle::DecodingConfig::setEagleConfig)
.def(py::pickle(decodingConfigGetstate, decodingConfigSetstate));
auto debugConfigGetstate = [](tle::DebugConfig const& self)
{
return py::make_tuple(self.getDebugInputTensors(), self.getDebugOutputTensors(), self.getDebugTensorNames(),
self.getDebugTensorsMaxIterations());
};
auto debugConfigSetstate = [](py::tuple const& state)
{
if (state.size() != 4)
{
throw std::runtime_error("Invalid state!");
}
return tle::DebugConfig(state[0].cast<bool>(), state[1].cast<bool>(), state[2].cast<std::vector<std::string>>(),
state[3].cast<SizeType32>());
};
py::class_<tle::DebugConfig>(m, "DebugConfig")
.def(py::init<bool, bool, std::vector<std::string>, SizeType32>(), py::arg("debug_input_tensors") = false,
py::arg("debug_output_tensors") = false, py::arg("debug_tensor_names") = py::none(),
py::arg("debug_tensors_max_iterations") = false)
.def_property(
"debug_input_tensors", &tle::DebugConfig::getDebugInputTensors, &tle::DebugConfig::setDebugInputTensors)
.def_property(
"debug_output_tensors", &tle::DebugConfig::getDebugOutputTensors, &tle::DebugConfig::setDebugOutputTensors)
.def_property(
"debug_tensor_names", &tle::DebugConfig::getDebugTensorNames, &tle::DebugConfig::setDebugTensorNames)
.def_property("debug_tensors_max_iterations", &tle::DebugConfig::getDebugTensorsMaxIterations,
&tle::DebugConfig::setDebugTensorsMaxIterations)
.def(py::pickle(debugConfigGetstate, debugConfigSetstate));
auto logitsPostProcessorConfigGetstate = [](tle::LogitsPostProcessorConfig const& self)
{ return py::make_tuple(self.getProcessorMap(), self.getProcessorBatched(), self.getReplicate()); };
auto logitsPostProcessorConfigSetstate = [](py::tuple const& state)
{
if (state.size() != 3)
{
throw std::runtime_error("Invalid LogitsPostProcessorConfig state!");
}
return tle::LogitsPostProcessorConfig(state[0].cast<std::optional<tle::LogitsPostProcessorMap>>(),
state[1].cast<std::optional<tle::LogitsPostProcessorBatched>>(), state[2].cast<bool>());
};
py::class_<tle::LogitsPostProcessorConfig>(m, "LogitsPostProcessorConfig")
.def(py::init<std::optional<tle::LogitsPostProcessorMap>, std::optional<tle::LogitsPostProcessorBatched>,
bool>(),
py::arg("processor_map") = py::none(), py::arg("processor_batched") = py::none(),
py::arg("replicate") = true)
.def_property("processor_map", &tle::LogitsPostProcessorConfig::getProcessorMap,
&tle::LogitsPostProcessorConfig::setProcessorMap)
.def_property("processor_batched", &tle::LogitsPostProcessorConfig::getProcessorBatched,
&tle::LogitsPostProcessorConfig::setProcessorBatched)
.def_property(
"replicate", &tle::LogitsPostProcessorConfig::getReplicate, &tle::LogitsPostProcessorConfig::setReplicate)
.def(py::pickle(logitsPostProcessorConfigGetstate, logitsPostProcessorConfigSetstate));
auto extendedRuntimePerfKnobConfigSetstate = [](py::tuple const& state)
{
if (state.size() != 4)
{
throw std::runtime_error("Invalid extendedRuntimePerfKnobConfig state!");
}
return tle::ExtendedRuntimePerfKnobConfig(
state[0].cast<bool>(), state[1].cast<bool>(), state[2].cast<bool>(), state[3].cast<SizeType32>());
};
auto extendedRuntimePerfKnobConfigGetstate = [](tle::ExtendedRuntimePerfKnobConfig const& self)
{
return py::make_tuple(self.getMultiBlockMode(), self.getEnableContextFMHAFP32Acc(), self.getCudaGraphMode(),
self.getCudaGraphCacheSize());
};
py::class_<tle::ExtendedRuntimePerfKnobConfig>(m, "ExtendedRuntimePerfKnobConfig")
.def(
py::init<bool, bool>(), py::arg("multi_block_mode") = true, py::arg("enable_context_fmha_fp32_acc") = false)
.def_property("multi_block_mode", &tle::ExtendedRuntimePerfKnobConfig::getMultiBlockMode,
&tle::ExtendedRuntimePerfKnobConfig::setMultiBlockMode)
.def_property("enable_context_fmha_fp32_acc", &tle::ExtendedRuntimePerfKnobConfig::getEnableContextFMHAFP32Acc,
&tle::ExtendedRuntimePerfKnobConfig::setEnableContextFMHAFP32Acc)
.def_property("cuda_graph_mode", &tle::ExtendedRuntimePerfKnobConfig::getCudaGraphMode,
&tle::ExtendedRuntimePerfKnobConfig::setCudaGraphMode)
.def_property("cuda_graph_cache_size", &tle::ExtendedRuntimePerfKnobConfig::getCudaGraphCacheSize,
&tle::ExtendedRuntimePerfKnobConfig::setCudaGraphCacheSize)
.def(py::pickle(extendedRuntimePerfKnobConfigGetstate, extendedRuntimePerfKnobConfigSetstate));
auto SpeculativeDecodingConfigGetState
= [](tle::SpeculativeDecodingConfig const& self) { return py::make_tuple(self.fastLogits); };
auto SpeculativeDecodingConfigSetState = [](py::tuple const& state)
{
if (state.size() != 1)
{
throw std::runtime_error("Invalid SpeculativeDecodingConfig state!");
}
return tle::SpeculativeDecodingConfig(state[0].cast<bool>());
};
py::class_<tle::SpeculativeDecodingConfig>(m, "SpeculativeDecodingConfig")
.def(py::init<bool>(), py::arg("fast_logits") = false)
.def_readwrite("fast_logits", &tle::SpeculativeDecodingConfig::fastLogits)
.def(py::pickle(SpeculativeDecodingConfigGetState, SpeculativeDecodingConfigSetState));
// Guided decoding config
auto pyGuidedDecodingConfig = py::class_<tle::GuidedDecodingConfig>(m, "GuidedDecodingConfig");
py::enum_<tle::GuidedDecodingConfig::GuidedDecodingBackend>(pyGuidedDecodingConfig, "GuidedDecodingBackend")
.value("XGRAMMAR", tle::GuidedDecodingConfig::GuidedDecodingBackend::kXGRAMMAR)
.value("LLGUIDANCE", tle::GuidedDecodingConfig::GuidedDecodingBackend::kLLGUIDANCE);
auto guidedDecodingConfigGetstate = [](tle::GuidedDecodingConfig const& self)
{
return py::make_tuple(
self.getBackend(), self.getEncodedVocab(), self.getTokenizerStr(), self.getStopTokenIds());
};
auto guidedDecodingConfigSetstate = [](py::tuple state)
{
if (state.size() != 4)
{
throw std::runtime_error("Invalid GuidedDecodingConfig state!");
}
return tle::GuidedDecodingConfig(state[0].cast<tle::GuidedDecodingConfig::GuidedDecodingBackend>(),
state[1].cast<std::optional<std::vector<std::string>>>(), state[2].cast<std::optional<std::string>>(),
state[3].cast<std::optional<std::vector<tle::TokenIdType>>>());
};
pyGuidedDecodingConfig
.def(py::init<tle::GuidedDecodingConfig::GuidedDecodingBackend, std::optional<std::vector<std::string>>,
std::optional<std::string>, std::optional<std::vector<tle::TokenIdType>>>(),
py::arg("backend"), py::arg("encoded_vocab") = py::none(), py::arg("tokenizer_str") = py::none(),
py::arg("stop_token_ids") = py::none())
.def_property("backend", &tle::GuidedDecodingConfig::getBackend, &tle::GuidedDecodingConfig::setBackend)
.def_property(
"encoded_vocab", &tle::GuidedDecodingConfig::getEncodedVocab, &tle::GuidedDecodingConfig::setEncodedVocab)
.def_property(
"tokenizer_str", &tle::GuidedDecodingConfig::getTokenizerStr, &tle::GuidedDecodingConfig::setTokenizerStr)
.def_property(
"stop_token_ids", &tle::GuidedDecodingConfig::getStopTokenIds, &tle::GuidedDecodingConfig::setStopTokenIds)
.def(py::pickle(guidedDecodingConfigGetstate, guidedDecodingConfigSetstate));
auto cacheTransceiverConfigGetstate = [](tle::CacheTransceiverConfig const& self)
{ return py::make_tuple(self.getBackendType(), self.getMaxTokensInBuffer(), self.getKvTransferTimeoutMs()); };
auto cacheTransceiverConfigSetstate = [](py::tuple const& state)
{
if (state.size() != 3)
{
throw std::runtime_error("Invalid CacheTransceiverConfig state!");
}
return tle::CacheTransceiverConfig(state[0].cast<tle::CacheTransceiverConfig::BackendType>(),
state[1].cast<std::optional<size_t>>(), state[2].cast<std::optional<int>>());
};
py::enum_<tle::CacheTransceiverConfig::BackendType>(m, "CacheTransceiverBackendType")
.value("DEFAULT", tle::CacheTransceiverConfig::BackendType::DEFAULT)
.value("MPI", tle::CacheTransceiverConfig::BackendType::MPI)
.value("UCX", tle::CacheTransceiverConfig::BackendType::UCX)
.value("NIXL", tle::CacheTransceiverConfig::BackendType::NIXL)
.value("MOONCAKE", tle::CacheTransceiverConfig::BackendType::MOONCAKE)
.def("from_string",
[](std::string const& str)
{
if (str == "DEFAULT" || str == "default")
return tle::CacheTransceiverConfig::BackendType::DEFAULT;
if (str == "MPI" || str == "mpi")
return tle::CacheTransceiverConfig::BackendType::MPI;
if (str == "UCX" || str == "ucx")
return tle::CacheTransceiverConfig::BackendType::UCX;
if (str == "NIXL" || str == "nixl")
return tle::CacheTransceiverConfig::BackendType::NIXL;
if (str == "MOONCAKE" || str == "mooncake")
return tle::CacheTransceiverConfig::BackendType::MOONCAKE;
throw std::runtime_error("Invalid backend type: " + str);
});
py::class_<tle::CacheTransceiverConfig>(m, "CacheTransceiverConfig")
.def(py::init<std::optional<tle::CacheTransceiverConfig::BackendType>, std::optional<size_t>,
std::optional<int>, std::optional<int>>(),
py::arg("backend") = std::nullopt, py::arg("max_tokens_in_buffer") = std::nullopt,
py::arg("kv_transfer_timeout_ms") = std::nullopt,
py::arg("kv_transfer_sender_future_timeout_ms") = std::nullopt)
.def_property(
"backend", &tle::CacheTransceiverConfig::getBackendType, &tle::CacheTransceiverConfig::setBackendType)
.def_property("max_tokens_in_buffer", &tle::CacheTransceiverConfig::getMaxTokensInBuffer,
&tle::CacheTransceiverConfig::setMaxTokensInBuffer)
.def_property("kv_transfer_timeout_ms", &tle::CacheTransceiverConfig::getKvTransferTimeoutMs,
&tle::CacheTransceiverConfig::setKvTransferTimeoutMs)
.def_property("kv_transfer_sender_future_timeout_ms",
&tle::CacheTransceiverConfig::getKvTransferSenderFutureTimeoutMs,
&tle::CacheTransceiverConfig::setKvTransferSenderFutureTimeoutMs)
.def(py::pickle(cacheTransceiverConfigGetstate, cacheTransceiverConfigSetstate));
auto executorConfigGetState = [](py::object const& self)
{
auto& c = self.cast<tle::ExecutorConfig&>();
// Return a tuple containing C++ data and the Python __dict__
auto cpp_states = py::make_tuple(c.getMaxBeamWidth(), c.getSchedulerConfig(), c.getKvCacheConfig(),
c.getEnableChunkedContext(), c.getNormalizeLogProbs(), c.getIterStatsMaxIterations(),
c.getRequestStatsMaxIterations(), c.getBatchingType(), c.getMaxBatchSize(), c.getMaxNumTokens(),
c.getParallelConfig(), c.getPeftCacheConfig(), c.getLogitsPostProcessorConfig(), c.getDecodingConfig(),
c.getUseGpuDirectStorage(), c.getGpuWeightsPercent(), c.getMaxQueueSize(),
c.getExtendedRuntimePerfKnobConfig(), c.getDebugConfig(), c.getRecvPollPeriodMs(),
c.getMaxSeqIdleMicroseconds(), c.getSpecDecConfig(), c.getGuidedDecodingConfig(),
c.getAdditionalModelOutputs(), c.getCacheTransceiverConfig(), c.getGatherGenerationLogits(),
c.getPromptTableOffloading(), c.getEnableTrtOverlap(), c.getFailFastOnAttentionWindowTooLarge());
auto pickle_tuple = py::make_tuple(cpp_states, py::getattr(self, "__dict__"));
return pickle_tuple;
};
auto executorConfigSetState = [](py::tuple const& state)
{
if (state.size() != 2)
{
throw std::runtime_error("Invalid state!");
}
// Restore C++ data
auto cpp_states = state[0].cast<py::tuple>();
if (cpp_states.size() != 29)
{
throw std::runtime_error("Invalid cpp_states!");
}
auto ec = tle::ExecutorConfig( //
cpp_states[0].cast<SizeType32>(), // MaxBeamWidth
cpp_states[1].cast<tle::SchedulerConfig>(), // SchedulerConfig
cpp_states[2].cast<tle::KvCacheConfig>(), // KvCacheConfig
cpp_states[3].cast<bool>(), // EnableChunkedContext
cpp_states[4].cast<bool>(), // NormalizeLogProbs
cpp_states[5].cast<SizeType32>(), // IterStatsMaxIterations
cpp_states[6].cast<SizeType32>(), // RequestStatsMaxIterations
cpp_states[7].cast<tle::BatchingType>(), // BatchingType
cpp_states[8].cast<std::optional<SizeType32>>(), // MaxBatchSize
cpp_states[9].cast<std::optional<SizeType32>>(), // MaxNumTokens
cpp_states[10].cast<std::optional<tle::ParallelConfig>>(), // ParallelConfig
cpp_states[11].cast<std::optional<tle::PeftCacheConfig>>(), // PeftCacheConfig
cpp_states[12].cast<std::optional<tle::LogitsPostProcessorConfig>>(), // LogitsPostProcessorConfig
cpp_states[13].cast<std::optional<tle::DecodingConfig>>(), // DecodingConfig
cpp_states[14].cast<bool>(), // UseGpuDirectStorage
cpp_states[15].cast<float>(), // GpuWeightsPercent
cpp_states[16].cast<std::optional<SizeType32>>(), // MaxQueueSize
cpp_states[17].cast<tle::ExtendedRuntimePerfKnobConfig>(), // ExtendedRuntimePerfKnobConfig
cpp_states[18].cast<std::optional<tle::DebugConfig>>(), // DebugConfig
cpp_states[19].cast<SizeType32>(), // RecvPollPeriodMs
cpp_states[20].cast<uint64_t>(), // MaxSeqIdleMicroseconds
cpp_states[21].cast<std::optional<tle::SpeculativeDecodingConfig>>(), // SpecDecConfig
cpp_states[22].cast<std::optional<tle::GuidedDecodingConfig>>(), // GuidedDecodingConfig
cpp_states[23].cast<std::optional<std::vector<tle::AdditionalModelOutput>>>(), // AdditionalModelOutputs
cpp_states[24].cast<std::optional<tle::CacheTransceiverConfig>>(), // CacheTransceiverConfig
cpp_states[25].cast<bool>(), // GatherGenerationLogits
cpp_states[26].cast<bool>(), // PromptTableOffloading
cpp_states[27].cast<bool>(), // EnableTrtOverlap
cpp_states[28].cast<bool>() // FailFastOnAttentionWindowTooLarge
);
auto py_state = state[1].cast<py::dict>();
return std::make_pair(ec, py_state);
};
py::class_<tle::ExecutorConfig>(m, "ExecutorConfig", pybind11::dynamic_attr())
.def(py::init< //
SizeType32, // MaxBeamWidth
tle::SchedulerConfig const&, // SchedulerConfig
tle::KvCacheConfig const&, // KvCacheConfig
bool, // EnableChunkedContext
bool, // NormalizeLogProbs
SizeType32, // IterStatsMaxIterations
SizeType32, // RequestStatsMaxIterations
tle::BatchingType, // BatchingType
std::optional<SizeType32>, // MaxBatchSize
std::optional<SizeType32>, // MaxNumTokens
std::optional<tle::ParallelConfig>, // ParallelConfig
tle::PeftCacheConfig const&, // PeftCacheConfig
std::optional<tle::LogitsPostProcessorConfig>, // LogitsPostProcessorConfig
std::optional<tle::DecodingConfig>, // DecodingConfig
bool, // UseGpuDirectStorage
float, // GpuWeightsPercent
std::optional<SizeType32>, // MaxQueueSize
tle::ExtendedRuntimePerfKnobConfig const&, // ExtendedRuntimePerfKnobConfig
std::optional<tle::DebugConfig>, // DebugConfig
SizeType32, // RecvPollPeriodMs
uint64_t, // MaxSeqIdleMicroseconds
std::optional<tle::SpeculativeDecodingConfig>, // SpecDecConfig
std::optional<tle::GuidedDecodingConfig>, // GuidedDecodingConfig
std::optional<std::vector<tle::AdditionalModelOutput>>, // AdditionalModelOutputs
std::optional<tle::CacheTransceiverConfig>, // CacheTransceiverConfig
bool, // GatherGenerationLogits
bool, // PromptTableOffloading
bool, // EnableTrtOverlap
bool // FailFastOnAttentionWindowTooLarge
>(),
py::arg("max_beam_width") = 1, py::arg_v("scheduler_config", tle::SchedulerConfig(), "SchedulerConfig()"),
py::arg_v("kv_cache_config", tle::KvCacheConfig(), "KvCacheConfig()"),
py::arg("enable_chunked_context") = false, py::arg("normalize_log_probs") = true,
py::arg("iter_stats_max_iterations") = tle::ExecutorConfig::kDefaultIterStatsMaxIterations,
py::arg("request_stats_max_iterations") = tle::ExecutorConfig::kDefaultRequestStatsMaxIterations,
py::arg_v("batching_type", tle::BatchingType::kINFLIGHT, "BatchingType.INFLIGHT"),
py::arg("max_batch_size") = py::none(), py::arg("max_num_tokens") = py::none(),
py::arg("parallel_config") = py::none(),
py::arg_v("peft_cache_config", tle::PeftCacheConfig(), "PeftCacheConfig()"),
py::arg("logits_post_processor_config") = py::none(), py::arg("decoding_config") = py::none(),
py::arg("use_gpu_direct_storage") = false, py::arg("gpu_weights_percent") = 1.0,
py::arg("max_queue_size") = py::none(),
py::arg_v("extended_runtime_perf_knob_config", tle::ExtendedRuntimePerfKnobConfig(),
"ExtendedRuntimePerfKnobConfig()"),
py::arg("debug_config") = py::none(), py::arg("recv_poll_period_ms") = 0,
py::arg("max_seq_idle_microseconds") = tle::ExecutorConfig::kDefaultMaxSeqIdleMicroseconds,
py::arg("spec_dec_config") = py::none(), py::arg("guided_decoding_config") = py::none(),
py::arg("additional_model_outputs") = py::none(), py::arg("cache_transceiver_config") = py::none(),
py::arg("gather_generation_logits") = false, py::arg("mm_embedding_offloading") = false,
py::arg("enable_trt_overlap") = false, py::arg("fail_fast_on_attention_window_too_large") = false)
.def_property("max_beam_width", &tle::ExecutorConfig::getMaxBeamWidth, &tle::ExecutorConfig::setMaxBeamWidth)
.def_property("max_batch_size", &tle::ExecutorConfig::getMaxBatchSize, &tle::ExecutorConfig::setMaxBatchSize)
.def_property("max_num_tokens", &tle::ExecutorConfig::getMaxNumTokens, &tle::ExecutorConfig::setMaxNumTokens)
.def_property(
"scheduler_config", &tle::ExecutorConfig::getSchedulerConfigRef, &tle::ExecutorConfig::setSchedulerConfig)
.def_property(
"kv_cache_config", &tle::ExecutorConfig::getKvCacheConfigRef, &tle::ExecutorConfig::setKvCacheConfig)
.def_property("enable_chunked_context", &tle::ExecutorConfig::getEnableChunkedContext,
&tle::ExecutorConfig::setEnableChunkedContext)
.def_property("normalize_log_probs", &tle::ExecutorConfig::getNormalizeLogProbs,
&tle::ExecutorConfig::setNormalizeLogProbs)
.def_property("iter_stats_max_iterations", &tle::ExecutorConfig::getIterStatsMaxIterations,
&tle::ExecutorConfig::setIterStatsMaxIterations)
.def_property("request_stats_max_iterations", &tle::ExecutorConfig::getRequestStatsMaxIterations,
&tle::ExecutorConfig::setRequestStatsMaxIterations)
.def_property("batching_type", &tle::ExecutorConfig::getBatchingType, &tle::ExecutorConfig::setBatchingType)
.def_property(
"parallel_config", &tle::ExecutorConfig::getParallelConfig, &tle::ExecutorConfig::setParallelConfig)
.def_property(
"peft_cache_config", &tle::ExecutorConfig::getPeftCacheConfig, &tle::ExecutorConfig::setPeftCacheConfig)
.def_property("logits_post_processor_config", &tle::ExecutorConfig::getLogitsPostProcessorConfig,
&tle::ExecutorConfig::setLogitsPostProcessorConfig)
.def_property(
"decoding_config", &tle::ExecutorConfig::getDecodingConfig, &tle::ExecutorConfig::setDecodingConfig)
.def_property("use_gpu_direct_storage", &tle::ExecutorConfig::getUseGpuDirectStorage,
&tle::ExecutorConfig::setUseGpuDirectStorage)
.def_property("gpu_weights_percent", &tle::ExecutorConfig::getGpuWeightsPercent,
&tle::ExecutorConfig::setGpuWeightsPercent)
.def_property("max_queue_size", &tle::ExecutorConfig::getMaxQueueSize, &tle::ExecutorConfig::setMaxQueueSize)
.def_property("extended_runtime_perf_knob_config", &tle::ExecutorConfig::getExtendedRuntimePerfKnobConfig,
&tle::ExecutorConfig::setExtendedRuntimePerfKnobConfig)
.def_property("debug_config", &tle::ExecutorConfig::getDebugConfig, &tle::ExecutorConfig::setDebugConfig)
.def_property(
"recv_poll_period_ms", &tle::ExecutorConfig::getRecvPollPeriodMs, &tle::ExecutorConfig::setRecvPollPeriodMs)
.def_property("max_seq_idle_microseconds", &tle::ExecutorConfig::getMaxSeqIdleMicroseconds,
&tle::ExecutorConfig::setMaxSeqIdleMicroseconds)
.def_property("spec_dec_config", &tle::ExecutorConfig::getSpecDecConfig, &tle::ExecutorConfig::setSpecDecConfig)
.def_property("guided_decoding_config", &tle::ExecutorConfig::getGuidedDecodingConfig,
&tle::ExecutorConfig::setGuidedDecodingConfig)
.def_property("additional_model_outputs", &tle::ExecutorConfig::getAdditionalModelOutputs,
&tle::ExecutorConfig::setAdditionalModelOutputs)
.def_property("cache_transceiver_config", &tle::ExecutorConfig::getCacheTransceiverConfig,
&tle::ExecutorConfig::setCacheTransceiverConfig)
.def_property("gather_generation_logits", &tle::ExecutorConfig::getGatherGenerationLogits,
&tle::ExecutorConfig::setGatherGenerationLogits)
.def_property("mm_embedding_offloading", &tle::ExecutorConfig::getPromptTableOffloading,
&tle::ExecutorConfig::setPromptTableOffloading)
.def_property(
"enable_trt_overlap", &tle::ExecutorConfig::getEnableTrtOverlap, &tle::ExecutorConfig::setEnableTrtOverlap)
.def_property("fail_fast_on_attention_window_too_large",
&tle::ExecutorConfig::getFailFastOnAttentionWindowTooLarge,
&tle::ExecutorConfig::setFailFastOnAttentionWindowTooLarge)
.def(py::pickle(executorConfigGetState, executorConfigSetState));
}
} // namespace tensorrt_llm::pybind::executor
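
For reference, the getstate/setstate lambdas in this file all follow the same pybind11 pickling pattern. Below is a minimal, self-contained sketch of that pattern under its own assumptions; the ToyConfig type and toy_bindings module are hypothetical and are not part of the removed sources.

#include <pybind11/pybind11.h>
#include <stdexcept>
#include <string>
#include <utility>
namespace py = pybind11;
// Hypothetical config type, used only to illustrate the py::pickle pattern above.
struct ToyConfig
{
    ToyConfig(bool enableFoo, std::string name)
        : enableFoo{enableFoo}
        , name{std::move(name)}
    {
    }
    bool enableFoo;
    std::string name;
};
PYBIND11_MODULE(toy_bindings, m)
{
    // __getstate__ packs the C++ members into a tuple; __setstate__ checks the tuple size
    // and rebuilds the object, mirroring the getstate/setstate lambdas in this file.
    auto toyConfigGetstate = [](ToyConfig const& self) { return py::make_tuple(self.enableFoo, self.name); };
    auto toyConfigSetstate = [](py::tuple const& state)
    {
        if (state.size() != 2)
        {
            throw std::runtime_error("Invalid ToyConfig state!");
        }
        return ToyConfig(state[0].cast<bool>(), state[1].cast<std::string>());
    };
    py::class_<ToyConfig>(m, "ToyConfig")
        .def(py::init<bool, std::string>(), py::arg("enable_foo") = false, py::arg("name") = "")
        .def_readwrite("enable_foo", &ToyConfig::enableFoo)
        .def_readwrite("name", &ToyConfig::name)
        .def(py::pickle(toyConfigGetstate, toyConfigSetstate));
}

The ExecutorConfig binding above extends this pattern: because the class is registered with py::dynamic_attr(), its __getstate__ pairs the C++ tuple with the instance's __dict__, and __setstate__ returns a std::pair so Python-side attributes survive pickling as well.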

View File

@ -1,29 +0,0 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/pybind/common/customCasters.h"
#include <pybind11/pybind11.h>
namespace tensorrt_llm::pybind::executor
{
// Register bindings for executor API.
void initConfigBindings(pybind11::module_& m);
} // namespace tensorrt_llm::pybind::executor
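
A declaration like this is typically consumed by the module's entry point and invoked against a submodule. The following is a hedged sketch of that wiring only; the module and submodule names are illustrative and the actual entry point lived elsewhere in the removed tree.

#include <pybind11/pybind11.h>
namespace py = pybind11;
namespace tensorrt_llm::pybind::executor
{
// Declared in the header above; defined in the executorConfig.cpp translation unit.
void initConfigBindings(pybind11::module_& m);
} // namespace tensorrt_llm::pybind::executor
// Hypothetical entry point; the real module and submodule layout is defined elsewhere.
PYBIND11_MODULE(example_bindings, m)
{
    auto executorModule = m.def_submodule("executor", "Executor API bindings");
    tensorrt_llm::pybind::executor::initConfigBindings(executorModule);
}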

View File

@ -1,927 +0,0 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "request.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/executor/serializeUtils.h"
#include "tensorrt_llm/executor/tensor.h"
#include "tensorrt_llm/executor/types.h"
#include "tensorrt_llm/runtime/cudaStream.h"
#include <pybind11/cast.h>
#include <pybind11/chrono.h>
#include <pybind11/functional.h>
#include <pybind11/operators.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <sstream>
#include <optional>
#include <vector>
namespace py = pybind11;
namespace tle = tensorrt_llm::executor;
using Tensor = tle::Tensor;
using SizeType32 = tle::SizeType32;
using FloatType = tle::FloatType;
using VecTokens = tle::VecTokens;
using IdType = tle::IdType;
using VecTokenExtraIds = tle::VecTokenExtraIds;
namespace tensorrt_llm::pybind::executor
{
void initRequestBindings(pybind11::module_& m)
{
py::enum_<tle::RequestType>(m, "RequestType")
.value("REQUEST_TYPE_CONTEXT_AND_GENERATION", tle::RequestType::REQUEST_TYPE_CONTEXT_AND_GENERATION)
.value("REQUEST_TYPE_CONTEXT_ONLY", tle::RequestType::REQUEST_TYPE_CONTEXT_ONLY)
.value("REQUEST_TYPE_GENERATION_ONLY", tle::RequestType::REQUEST_TYPE_GENERATION_ONLY);
py::enum_<tle::FinishReason>(m, "FinishReason")
.value("NOT_FINISHED", tle::FinishReason::kNOT_FINISHED)
.value("END_ID", tle::FinishReason::kEND_ID)
.value("STOP_WORDS", tle::FinishReason::kSTOP_WORDS)
.value("LENGTH", tle::FinishReason::kLENGTH)
.value("TIMED_OUT", tle::FinishReason::kTIMED_OUT)
.value("CANCELLED", tle::FinishReason::kCANCELLED);
py::enum_<tle::KvCacheTransferMode>(m, "KvCacheTransferMode")
.value("DRAM", tle::KvCacheTransferMode::DRAM)
.value("GDS", tle::KvCacheTransferMode::GDS)
.value("POSIX_DEBUG_FALLBACK", tle::KvCacheTransferMode::POSIX_DEBUG_FALLBACK);
auto samplingConfigGetstate = [](tle::SamplingConfig const& self)
{
return py::make_tuple(self.getBeamWidth(), self.getTopK(), self.getTopP(), self.getTopPMin(),
self.getTopPResetIds(), self.getTopPDecay(), self.getSeed(), self.getTemperature(), self.getMinTokens(),
self.getBeamSearchDiversityRate(), self.getRepetitionPenalty(), self.getPresencePenalty(),
self.getFrequencyPenalty(), self.getPromptIgnoreLength(), self.getLengthPenalty(), self.getEarlyStopping(),
self.getNoRepeatNgramSize(), self.getNumReturnSequences(), self.getMinP(), self.getBeamWidthArray());
};
auto samplingConfigSetstate = [](py::tuple const& state)
{
if (state.size() != 20)
{
throw std::runtime_error("Invalid SamplingConfig state!");
}
return tle::SamplingConfig(state[0].cast<SizeType32>(), // BeamWidth
state[1].cast<std::optional<SizeType32>>(), // TopK
state[2].cast<std::optional<FloatType>>(), // TopP
state[3].cast<std::optional<FloatType>>(), // TopPMin
state[4].cast<std::optional<tle::TokenIdType>>(), // TopPResetIds
state[5].cast<std::optional<FloatType>>(), // TopPDecay
state[6].cast<std::optional<tle::RandomSeedType>>(), // Seed
state[7].cast<std::optional<FloatType>>(), // Temperature
state[8].cast<std::optional<SizeType32>>(), // MinTokens
state[9].cast<std::optional<FloatType>>(), // BeamSearchDiversityRate
state[10].cast<std::optional<FloatType>>(), // RepetitionPenalty
state[11].cast<std::optional<FloatType>>(), // PresencePenalty
state[12].cast<std::optional<FloatType>>(), // FrequencyPenalty
state[13].cast<std::optional<SizeType32>>(), // PromptIgnoreLength
state[14].cast<std::optional<FloatType>>(), // LengthPenalty
state[15].cast<std::optional<SizeType32>>(), // EarlyStopping
state[16].cast<std::optional<SizeType32>>(), // NoRepeatNgramSize
state[17].cast<std::optional<SizeType32>>(), // NumReturnSequences
state[18].cast<std::optional<FloatType>>(), // MinP
state[19].cast<std::optional<std::vector<SizeType32>>>() // BeamWidthArray
);
};
py::class_<tle::SamplingConfig>(m, "SamplingConfig")
.def(py::init<tle::SizeType32,
std::optional<tle::SizeType32> const&, // topK
std::optional<tle::FloatType> const&, // topP
std::optional<tle::FloatType> const&, // topPMin
std::optional<tle::TokenIdType> const&, // topPResetIds
std::optional<tle::FloatType> const&, // topPDecay
std::optional<tle::RandomSeedType> const&, // seed
std::optional<tle::FloatType> const&, // temperature
std::optional<tle::SizeType32> const&, // minTokens
std::optional<tle::FloatType> const&, // beamSearchDiversityRate
std::optional<tle::FloatType> const&, // repetitionPenalty
std::optional<tle::FloatType> const&, // presencePenalty
std::optional<tle::FloatType> const&, // frequencyPenalty
std::optional<tle::SizeType32> const&, // promptIgnoreLength
std::optional<tle::FloatType> const&, // lengthPenalty
std::optional<tle::SizeType32> const&, // earlyStopping
std::optional<tle::SizeType32> const&, // noRepeatNgramSize
std::optional<tle::SizeType32> const&, // numReturnSequences
std::optional<tle::FloatType> const&, // minP
std::optional<std::vector<tle::SizeType32>> const& // beamWidthArray
>(),
// clang-format off
py::arg("beam_width") = 1,
py::kw_only(),
py::arg("top_k") = py::none(),
py::arg("top_p") = py::none(),
py::arg("top_p_min") = py::none(),
py::arg("top_p_reset_ids") = py::none(),
py::arg("top_p_decay") = py::none(),
py::arg("seed") = py::none(),
py::arg("temperature") = py::none(),
py::arg("min_tokens") = py::none(),
py::arg("beam_search_diversity_rate") = py::none(),
py::arg("repetition_penalty") = py::none(),
py::arg("presence_penalty") = py::none(),
py::arg("frequency_penalty") = py::none(),
py::arg("prompt_ignore_length") = py::none(),
py::arg("length_penalty") = py::none(),
py::arg("early_stopping") = py::none(),
py::arg("no_repeat_ngram_size") = py::none(),
py::arg("num_return_sequences") = py::none(),
py::arg("min_p") = py::none(),
py::arg("beam_width_array") = py::none()) // clang-format on
.def_property("beam_width", &tle::SamplingConfig::getBeamWidth, &tle::SamplingConfig::setBeamWidth)
.def_property("top_k", &tle::SamplingConfig::getTopK, &tle::SamplingConfig::setTopK)
.def_property("top_p", &tle::SamplingConfig::getTopP, &tle::SamplingConfig::setTopP)
.def_property("top_p_min", &tle::SamplingConfig::getTopPMin, &tle::SamplingConfig::setTopPMin)
.def_property("top_p_reset_ids", &tle::SamplingConfig::getTopPResetIds, &tle::SamplingConfig::setTopPResetIds)
.def_property("top_p_decay", &tle::SamplingConfig::getTopPDecay, &tle::SamplingConfig::setTopPDecay)
.def_property("seed", &tle::SamplingConfig::getSeed, &tle::SamplingConfig::setSeed)
.def_property("temperature", &tle::SamplingConfig::getTemperature, &tle::SamplingConfig::setTemperature)
.def_property("min_tokens", &tle::SamplingConfig::getMinTokens, &tle::SamplingConfig::setMinTokens)
.def_property("beam_search_diversity_rate", &tle::SamplingConfig::getBeamSearchDiversityRate,
&tle::SamplingConfig::setBeamSearchDiversityRate)
.def_property("repetition_penalty", &tle::SamplingConfig::getRepetitionPenalty,
&tle::SamplingConfig::setRepetitionPenalty)
.def_property("presence_penalty", &tle::SamplingConfig::getPresencePenalty,
[](tle::SamplingConfig& self, std::optional<FloatType> v) { self.setPresencePenalty(v); })
.def_property(
"frequency_penalty", &tle::SamplingConfig::getFrequencyPenalty, &tle::SamplingConfig::setFrequencyPenalty)
.def_property("prompt_ignore_length", &tle::SamplingConfig::getPromptIgnoreLength,
&tle::SamplingConfig::setPromptIgnoreLength)
.def_property("length_penalty", &tle::SamplingConfig::getLengthPenalty, &tle::SamplingConfig::setLengthPenalty)
.def_property("early_stopping", &tle::SamplingConfig::getEarlyStopping, &tle::SamplingConfig::setEarlyStopping)
.def_property("no_repeat_ngram_size", &tle::SamplingConfig::getNoRepeatNgramSize,
&tle::SamplingConfig::setNoRepeatNgramSize)
.def_property("num_return_sequences", &tle::SamplingConfig::getNumReturnSequences,
&tle::SamplingConfig::setNumReturnSequences)
.def_property("min_p", &tle::SamplingConfig::getMinP, &tle::SamplingConfig::setMinP)
.def_property(
"beam_width_array", &tle::SamplingConfig::getBeamWidthArray, &tle::SamplingConfig::setBeamWidthArray)
.def(py::pickle(samplingConfigGetstate, samplingConfigSetstate));
auto additionalModelOutputGetstate
= [](tle::AdditionalModelOutput const& self) { return py::make_tuple(self.name, self.gatherContext); };
auto additionalModelOutputSetstate = [](py::tuple const& state)
{
if (state.size() != 2)
{
throw std::runtime_error("Invalid AdditionalModelOutput state!");
}
return tle::AdditionalModelOutput(state[0].cast<std::string>(), state[1].cast<bool>());
};
py::class_<tle::AdditionalModelOutput>(m, "AdditionalModelOutput")
.def(py::init<std::string, bool>(), py::arg("name"), py::arg("gather_context") = false)
.def_readwrite("name", &tle::AdditionalModelOutput::name)
.def_readwrite("gather_context", &tle::AdditionalModelOutput::gatherContext)
.def(py::pickle(additionalModelOutputGetstate, additionalModelOutputSetstate));
auto outputConfigGetstate = [](tle::OutputConfig const& self)
{
return py::make_tuple(self.returnLogProbs, self.returnContextLogits, self.returnGenerationLogits,
self.excludeInputFromOutput, self.returnEncoderOutput, self.returnPerfMetrics, self.additionalModelOutputs);
};
auto outputConfigSetstate = [](py::tuple const& state)
{
if (state.size() != 7)
{
throw std::runtime_error("Invalid OutputConfig state!");
}
return tle::OutputConfig(state[0].cast<bool>(), state[1].cast<bool>(), state[2].cast<bool>(),
state[3].cast<bool>(), state[4].cast<bool>(), state[5].cast<bool>(),
state[6].cast<std::optional<std::vector<tle::AdditionalModelOutput>>>());
};
py::class_<tle::OutputConfig>(m, "OutputConfig")
.def(py::init<bool, bool, bool, bool, bool, bool, std::optional<std::vector<tle::AdditionalModelOutput>>>(),
py::arg("return_log_probs") = false, py::arg("return_context_logits") = false,
py::arg("return_generation_logits") = false, py::arg("exclude_input_from_output") = false,
py::arg("return_encoder_output") = false, py::arg("return_perf_metrics") = false,
py::arg("additional_model_outputs") = py::none())
.def_readwrite("return_log_probs", &tle::OutputConfig::returnLogProbs)
.def_readwrite("return_context_logits", &tle::OutputConfig::returnContextLogits)
.def_readwrite("return_generation_logits", &tle::OutputConfig::returnGenerationLogits)
.def_readwrite("exclude_input_from_output", &tle::OutputConfig::excludeInputFromOutput)
.def_readwrite("return_encoder_output", &tle::OutputConfig::returnEncoderOutput)
.def_readwrite("return_perf_metrics", &tle::OutputConfig::returnPerfMetrics)
.def_readwrite("additional_model_outputs", &tle::OutputConfig::additionalModelOutputs)
.def(py::pickle(outputConfigGetstate, outputConfigSetstate));
auto externalDraftTokensConfigGetstate = [](tle::ExternalDraftTokensConfig const& self)
{ return py::make_tuple(self.getTokens(), self.getLogits(), self.getAcceptanceThreshold()); };
auto externalDraftTokensConfigSetstate = [](py::tuple const& state)
{
if (state.size() != 3)
{
throw std::runtime_error("Invalid ExternalDraftTokensConfig state!");
}
return tle::ExternalDraftTokensConfig(state[0].cast<VecTokens>(), state[1].cast<std::optional<Tensor>>(),
state[2].cast<std::optional<FloatType>>());
};
py::class_<tle::ExternalDraftTokensConfig>(m, "ExternalDraftTokensConfig")
.def(py::init<VecTokens, std::optional<Tensor>, std::optional<FloatType> const&, std::optional<bool>>(),
py::arg("tokens"), py::arg("logits") = py::none(), py::arg("acceptance_threshold") = py::none(),
py::arg("fast_logits") = py::none())
.def_property_readonly("tokens", &tle::ExternalDraftTokensConfig::getTokens)
.def_property_readonly("logits", &tle::ExternalDraftTokensConfig::getLogits)
.def_property_readonly("acceptance_threshold", &tle::ExternalDraftTokensConfig::getAcceptanceThreshold)
.def(py::pickle(externalDraftTokensConfigGetstate, externalDraftTokensConfigSetstate))
.def_property_readonly("fast_logits", &tle::ExternalDraftTokensConfig::getFastLogits);
auto promptTuningConfigGetstate = [](tle::PromptTuningConfig const& self)
{ return py::make_tuple(self.getEmbeddingTable(), self.getInputTokenExtraIds()); };
auto promptTuningConfigSetstate = [](py::tuple const& state)
{
if (state.size() != 2)
{
throw std::runtime_error("Invalid PromptTuningConfig state!");
}
return tle::PromptTuningConfig(state[0].cast<Tensor>(), state[1].cast<std::optional<VecTokenExtraIds>>());
};
py::class_<tle::PromptTuningConfig>(m, "PromptTuningConfig")
.def(py::init<Tensor, std::optional<VecTokenExtraIds>>(), py::arg("embedding_table"),
py::arg("input_token_extra_ids") = py::none())
.def_property_readonly("embedding_table", &tle::PromptTuningConfig::getEmbeddingTable)
.def_property_readonly("input_token_extra_ids", &tle::PromptTuningConfig::getInputTokenExtraIds)
.def(py::pickle(promptTuningConfigGetstate, promptTuningConfigSetstate));
auto loraConfigGetstate = [](tle::LoraConfig const& self)
{ return py::make_tuple(self.getTaskId(), self.getWeights(), self.getConfig()); };
auto loraConfigSetstate = [](py::tuple const& state)
{
if (state.size() != 3)
{
throw std::runtime_error("Invalid LoraConfig state!");
}
return tle::LoraConfig(
state[0].cast<IdType>(), state[1].cast<std::optional<Tensor>>(), state[2].cast<std::optional<Tensor>>());
};
py::class_<tle::LoraConfig>(m, "LoraConfig")
.def(py::init<uint64_t, std::optional<Tensor>, std::optional<Tensor>>(), py::arg("task_id"),
py::arg("weights") = py::none(), py::arg("config") = py::none())
.def_property_readonly("task_id", &tle::LoraConfig::getTaskId)
.def_property_readonly("weights", &tle::LoraConfig::getWeights)
.def_property_readonly("config", &tle::LoraConfig::getConfig)
.def(py::pickle(loraConfigGetstate, loraConfigSetstate));
auto multimodalInputGetstate = [](tle::MultimodalInput const& self)
{ return py::make_tuple(self.getMultimodalHashes(), self.getMultimodalPositions(), self.getMultimodalLengths()); };
auto multimodalInputSetstate = [](py::tuple const& state)
{
if (state.size() != 3)
{
throw std::runtime_error("Invalid MultimodalInput state!");
}
return tle::MultimodalInput(state[0].cast<std::vector<std::vector<SizeType32>>>(),
state[1].cast<std::vector<SizeType32>>(), state[2].cast<std::vector<SizeType32>>());
};
py::class_<tle::MultimodalInput>(m, "MultimodalInput")
.def(py::init<std::vector<std::vector<SizeType32>>, std::vector<SizeType32>, std::vector<SizeType32>>(),
py::arg("multimodal_hashes"), py::arg("multimodal_positions"), py::arg("multimodal_lengths"))
.def_property_readonly("multimodal_hashes", &tle::MultimodalInput::getMultimodalHashes)
.def_property_readonly("multimodal_positions", &tle::MultimodalInput::getMultimodalPositions)
.def_property_readonly("multimodal_lengths", &tle::MultimodalInput::getMultimodalLengths)
.def(py::pickle(multimodalInputGetstate, multimodalInputSetstate));
auto MropeConfigGetstate = [](tle::MropeConfig const& self)
{ return py::make_tuple(self.getMRopeRotaryCosSin(), self.getMRopePositionDeltas()); };
auto MropeConfigSetstate = [](py::tuple const& state)
{
if (state.size() != 2)
{
throw std::runtime_error("Invalid MropeConfig state!");
}
return tle::MropeConfig(state[0].cast<tle::Tensor>(), state[1].cast<SizeType32>());
};
py::class_<tle::MropeConfig>(m, "MropeConfig")
.def(py::init<Tensor, SizeType32>(), py::arg("mrope_rotary_cos_sin"), py::arg("mrope_position_deltas"))
.def_property_readonly("mrope_rotary_cos_sin", &tle::MropeConfig::getMRopeRotaryCosSin)
.def_property_readonly("mrope_position_deltas", &tle::MropeConfig::getMRopePositionDeltas)
.def(py::pickle(MropeConfigGetstate, MropeConfigSetstate));
auto lookaheadDecodingConfigGetstate = [](tle::LookaheadDecodingConfig const& self)
{ return py::make_tuple(self.getWindowSize(), self.getNgramSize(), self.getVerificationSetSize()); };
auto lookaheadDecodingConfigSetstate = [](py::tuple const& state)
{
if (state.size() != 3)
{
throw std::runtime_error("Invalid LookaheadDecodingConfig state!");
}
return tle::LookaheadDecodingConfig(
state[0].cast<SizeType32>(), state[1].cast<SizeType32>(), state[2].cast<SizeType32>());
};
py::class_<tle::LookaheadDecodingConfig>(m, "LookaheadDecodingConfig")
.def(py::init<SizeType32, SizeType32, SizeType32>(), py::arg("max_window_size"), py::arg("max_ngram_size"),
py::arg("max_verification_set_size"))
.def_property_readonly("max_window_size", &tle::LookaheadDecodingConfig::getWindowSize)
.def_property_readonly("max_ngram_size", &tle::LookaheadDecodingConfig::getNgramSize)
.def_property_readonly("max_verification_set_size", &tle::LookaheadDecodingConfig::getVerificationSetSize)
.def("calculate_speculative_resource", &tle::LookaheadDecodingConfig::calculateSpeculativeResource)
.def_static(
"calculate_speculative_resource_tuple", &tle::LookaheadDecodingConfig::calculateSpeculativeResourceTuple)
.def(py::pickle(lookaheadDecodingConfigGetstate, lookaheadDecodingConfigSetstate))
.def_static("get_default_lookahead_decoding_window",
[]() { return tle::LookaheadDecodingConfig::kDefaultLookaheadDecodingWindow; })
.def_static("get_default_lookahead_decoding_ngram",
[]() { return tle::LookaheadDecodingConfig::kDefaultLookaheadDecodingNgram; })
.def_static("get_default_lookahead_decoding_verification_set",
[]() { return tle::LookaheadDecodingConfig::kDefaultLookaheadDecodingVerificationSet; });
auto TokenRangeRetentionConfigGetstate = [](tle::KvCacheRetentionConfig::TokenRangeRetentionConfig const& self)
{ return py::make_tuple(self.tokenStart, self.tokenEnd, self.priority, self.durationMs); };
auto TokenRangeRetentionConfigSetstate = [](py::tuple const& state)
{
if (state.size() != 4)
{
throw std::runtime_error("Invalid state!");
}
return tle::KvCacheRetentionConfig::TokenRangeRetentionConfig(state[0].cast<SizeType32>(),
state[1].cast<std::optional<SizeType32>>(), state[2].cast<tle::RetentionPriority>(),
state[3].cast<std::optional<std::chrono::milliseconds>>());
};
auto kvCacheRetentionConfigGetstate = [](tle::KvCacheRetentionConfig const& self)
{
return py::make_tuple(self.getTokenRangeRetentionConfigs(), self.getDecodeRetentionPriority(),
self.getDecodeDurationMs(), self.getTransferMode(), self.getDirectory());
};
auto kvCacheRetentionConfigSetstate = [](py::tuple const& state)
{
if (state.size() != 5)
{
throw std::runtime_error("Invalid state!");
}
return tle::KvCacheRetentionConfig(
state[0].cast<std::vector<tle::KvCacheRetentionConfig::TokenRangeRetentionConfig>>(),
state[1].cast<tle::RetentionPriority>(), state[2].cast<std::optional<std::chrono::milliseconds>>(),
state[3].cast<tle::KvCacheTransferMode>(), state[4].cast<std::string>());
};
auto kvCacheRetentionConfig = py::class_<tle::KvCacheRetentionConfig>(m, "KvCacheRetentionConfig");
py::class_<tle::KvCacheRetentionConfig::TokenRangeRetentionConfig>(
kvCacheRetentionConfig, "TokenRangeRetentionConfig")
.def(py::init<SizeType32, std::optional<SizeType32>, tle::RetentionPriority,
std::optional<std::chrono::milliseconds>>(),
py::arg("token_start"), py::arg("token_end"), py::arg("priority"), py::arg("duration_ms") = py::none())
.def_readwrite("token_start", &tle::KvCacheRetentionConfig::TokenRangeRetentionConfig::tokenStart)
.def_readwrite("token_end", &tle::KvCacheRetentionConfig::TokenRangeRetentionConfig::tokenEnd)
.def_readwrite("priority", &tle::KvCacheRetentionConfig::TokenRangeRetentionConfig::priority)
.def_readwrite("duration_ms", &tle::KvCacheRetentionConfig::TokenRangeRetentionConfig::durationMs)
.def(py::pickle(TokenRangeRetentionConfigGetstate, TokenRangeRetentionConfigSetstate))
.def("__eq__", &tle::KvCacheRetentionConfig::TokenRangeRetentionConfig::operator==);
// There's a circular dependency between the declaration of the TokenRangeRetentionConfig and
// KvCacheRetentionConfig bindings. Defer the definition of the KvCacheRetentionConfig bindings until the
// nested TokenRangeRetentionConfig bindings have been defined.
kvCacheRetentionConfig
.def(py::init<std::vector<tle::KvCacheRetentionConfig::TokenRangeRetentionConfig>, tle::RetentionPriority,
std::optional<std::chrono::milliseconds>, tle::KvCacheTransferMode, std::string>(),
py::arg("token_range_retention_configs"),
py::arg("decode_retention_priority") = tle::KvCacheRetentionConfig::kDefaultRetentionPriority,
py::arg("decode_duration_ms") = py::none(),
py::arg_v("transfer_mode", tle::KvCacheTransferMode::DRAM, "DRAM"), py::arg("directory") = py::none())
.def_property_readonly(
"token_range_retention_configs", &tle::KvCacheRetentionConfig::getTokenRangeRetentionConfigs)
.def_property_readonly("decode_retention_priority", &tle::KvCacheRetentionConfig::getDecodeRetentionPriority)
.def_property_readonly("decode_duration_ms", &tle::KvCacheRetentionConfig::getDecodeDurationMs)
.def_property_readonly("transfer_mode", &tle::KvCacheRetentionConfig::getTransferMode)
.def_property_readonly("directory", &tle::KvCacheRetentionConfig::getDirectory)
.def(py::pickle(kvCacheRetentionConfigGetstate, kvCacheRetentionConfigSetstate))
.def("__eq__", &tle::KvCacheRetentionConfig::operator==);
auto ContextPhaseParamsGetState = [](tle::ContextPhaseParams const& self)
{
if (self.getState() != nullptr)
{
auto serializedState = self.getSerializedState();
return py::make_tuple(self.getFirstGenTokens(), self.getReqId(),
py::bytes(serializedState.data(), serializedState.size()), self.getDraftTokens(), self.getCtxDpRank(),
self.getDisaggInfoEndpoint());
}
return py::make_tuple(self.getFirstGenTokens(), self.getReqId(), py::none(), self.getDraftTokens(),
self.getCtxDpRank(), self.getDisaggInfoEndpoint());
};
auto ContextPhaseParamsSetState = [](py::tuple const& state)
{
if (state.size() != 6)
{
throw std::runtime_error("Invalid ContextPhaseParams state!");
}
if (!state[2].is_none())
{
auto opaque_state = state[2].cast<py::bytes>();
auto opaque_state_str_view = std::string_view(opaque_state.cast<std::string_view>());
return std::make_unique<tle::ContextPhaseParams>(state[0].cast<VecTokens>(),
state[1].cast<tle::ContextPhaseParams::RequestIdType>(),
std::vector<char>(opaque_state_str_view.begin(), opaque_state_str_view.end()),
state[3].cast<std::optional<VecTokens>>(), state[4].cast<std::optional<SizeType32>>(),
state[5].cast<std::optional<std::string>>());
}
return std::make_unique<tle::ContextPhaseParams>(state[0].cast<VecTokens>(),
state[1].cast<tle::ContextPhaseParams::RequestIdType>(), state[3].cast<std::optional<VecTokens>>(),
state[4].cast<std::optional<SizeType32>>(), state[5].cast<std::optional<std::string>>());
};
py::class_<tle::ContextPhaseParams>(m, "ContextPhaseParams")
.def(py::init(
[](VecTokens const& first_gen_tokens, tle::ContextPhaseParams::RequestIdType req_id,
std::optional<py::bytes> const& opaque_state, std::optional<VecTokens> const& draft_tokens,
std::optional<SizeType32> const& ctx_dp_rank,
std::optional<std::string> const& disagg_info_endpoint)
{
if (opaque_state)
{
auto opaque_state_str_view = std::string_view(opaque_state.value().cast<std::string_view>());
return std::make_unique<tle::ContextPhaseParams>(first_gen_tokens, req_id,
std::vector<char>(opaque_state_str_view.begin(), opaque_state_str_view.end()),
draft_tokens, ctx_dp_rank, disagg_info_endpoint);
}
return std::make_unique<tle::ContextPhaseParams>(
first_gen_tokens, req_id, draft_tokens, ctx_dp_rank, disagg_info_endpoint);
}),
py::arg("first_gen_tokens"), py::arg("req_id"), py::arg("opaque_state") = py::none(),
py::arg("draft_tokens") = py::none(), py::arg("ctx_dp_rank") = py::none(),
py::arg("disagg_info_endpoint") = py::none())
.def_property("first_gen_tokens", &tle::ContextPhaseParams::getFirstGenTokens,
&tle::ContextPhaseParams::setFirstGenTokens)
.def_property(
"draft_tokens", &tle::ContextPhaseParams::getDraftTokens, &tle::ContextPhaseParams::setDraftTokens)
.def_property("req_id", &tle::ContextPhaseParams::getReqId, &tle::ContextPhaseParams::setReqId)
.def_property("ctx_dp_rank", &tle::ContextPhaseParams::getCtxDpRank, &tle::ContextPhaseParams::setCtxDpRank)
.def_property("disagg_info_endpoint", &tle::ContextPhaseParams::getDisaggInfoEndpoint,
&tle::ContextPhaseParams::setDisaggInfoEndpoint)
.def_property_readonly("opaque_state",
[](tle::ContextPhaseParams const& self)
{
std::optional<py::bytes> opaque_state{std::nullopt};
if (self.getState() != nullptr)
{
auto serializedState = self.getSerializedState();
opaque_state = py::bytes(serializedState.data(), serializedState.size());
}
return opaque_state;
})
.def(py::pickle(ContextPhaseParamsGetState, ContextPhaseParamsSetState));
auto EagleDecodingConfigGetstate = [](tle::EagleConfig const& self)
{
return py::make_tuple(self.getEagleChoices(), self.isGreedySampling(), self.getPosteriorThreshold(),
self.useDynamicTree(), self.getDynamicTreeMaxTopK());
};
auto EagleDecodingConfigSetstate = [](py::tuple const& state)
{
if (state.size() != 5)
{
throw std::runtime_error("Invalid EagleConfig state!");
}
return tle::EagleConfig(state[0].cast<std::optional<tle::EagleChoices>>(), state[1].cast<bool>(),
state[2].cast<std::optional<float>>(), state[3].cast<bool>(), state[4].cast<std::optional<SizeType32>>());
};
py::class_<tle::EagleConfig>(m, "EagleConfig")
.def(py::init<std::optional<tle::EagleChoices>, bool, std::optional<float>, bool, std::optional<SizeType32>>(),
py::arg("eagle_choices") = py::none(), py::arg("greedy_sampling") = true,
py::arg("posterior_threshold") = py::none(), py::arg("use_dynamic_tree") = false,
py::arg("dynamic_tree_max_topK") = py::none())
.def_property_readonly("eagle_choices", &tle::EagleConfig::getEagleChoices)
.def_property_readonly("greedy_sampling", &tle::EagleConfig::isGreedySampling)
.def_property_readonly("posterior_threshold", &tle::EagleConfig::getPosteriorThreshold)
.def_property_readonly("use_dynamic_tree", &tle::EagleConfig::useDynamicTree)
.def_property_readonly("dynamic_tree_max_topK", &tle::EagleConfig::getDynamicTreeMaxTopK)
.def(py::pickle(EagleDecodingConfigGetstate, EagleDecodingConfigSetstate));
// Guided decoding params
auto pyGuidedDecodingParams = py::class_<tle::GuidedDecodingParams>(m, "GuidedDecodingParams");
py::enum_<tle::GuidedDecodingParams::GuideType>(pyGuidedDecodingParams, "GuideType")
.value("JSON", tle::GuidedDecodingParams::GuideType::kJSON)
.value("JSON_SCHEMA", tle::GuidedDecodingParams::GuideType::kJSON_SCHEMA)
.value("REGEX", tle::GuidedDecodingParams::GuideType::kREGEX)
.value("EBNF_GRAMMAR", tle::GuidedDecodingParams::GuideType::kEBNF_GRAMMAR)
.value("STRUCTURAL_TAG", tle::GuidedDecodingParams::GuideType::kSTRUCTURAL_TAG);
auto guidedDecodingParamsGetstate
= [](tle::GuidedDecodingParams const& self) { return py::make_tuple(self.getGuideType(), self.getGuide()); };
auto guidedDecodingParamsSetstate = [](py::tuple state)
{
if (state.size() != 2)
{
throw std::runtime_error("Invalid GuidedDecodingParams state!");
}
return tle::GuidedDecodingParams(
state[0].cast<tle::GuidedDecodingParams::GuideType>(), state[1].cast<std::optional<std::string>>());
};
pyGuidedDecodingParams
.def(py::init<tle::GuidedDecodingParams::GuideType, std::optional<std::string>>(), py::arg("guide_type"),
py::arg("guide") = py::none())
.def_property_readonly("guide_type", &tle::GuidedDecodingParams::getGuideType)
.def_property_readonly("guide", &tle::GuidedDecodingParams::getGuide)
.def(py::pickle(guidedDecodingParamsGetstate, guidedDecodingParamsSetstate));
auto requestGetstate = [](tle::Request const& self)
{
return py::make_tuple(self.getInputTokenIds(), self.getMaxTokens(), self.getStreaming(),
self.getSamplingConfig(), self.getOutputConfig(), self.getEndId(), self.getPadId(), self.getPositionIds(),
self.getBadWords(), self.getStopWords(), self.getEmbeddingBias(), self.getExternalDraftTokensConfig(),
self.getPromptTuningConfig(), self.getMultimodalInput(), self.getMultimodalEmbedding(),
self.getMropeConfig(), self.getLoraConfig(), self.getLookaheadConfig(), self.getKvCacheRetentionConfig(),
self.getLogitsPostProcessorName(), self.getLogitsPostProcessor(), self.getEncoderInputTokenIds(),
self.getClientId(), self.getReturnAllGeneratedTokens(), self.getPriority(), self.getRequestType(),
self.getContextPhaseParams(), self.getEncoderInputFeatures(), self.getEncoderOutputLength(),
self.getCrossAttentionMask(), self.getEagleConfig(), self.getSkipCrossAttnBlocks(),
self.getGuidedDecodingParams(), self.getCacheSaltID(), self.getDisaggRequestId());
};
auto requestSetstate = [](py::tuple const& state)
{
if (state.size() != 35)
{
throw std::runtime_error("Invalid Request state!");
}
return std::make_unique<tle::Request>(state[0].cast<VecTokens>(), state[1].cast<SizeType32>(),
state[2].cast<bool>(), state[3].cast<tle::SamplingConfig>(), state[4].cast<tle::OutputConfig>(),
state[5].cast<std::optional<SizeType32>>(), state[6].cast<std::optional<SizeType32>>(),
state[7].cast<std::optional<std::vector<SizeType32>>>(),
state[8].cast<std::optional<std::list<VecTokens>>>(), state[9].cast<std::optional<std::list<VecTokens>>>(),
state[10].cast<std::optional<Tensor>>(), state[11].cast<std::optional<tle::ExternalDraftTokensConfig>>(),
state[12].cast<std::optional<tle::PromptTuningConfig>>(),
state[13].cast<std::optional<tle::MultimodalInput>>(), state[14].cast<std::optional<Tensor>>(),
state[15].cast<std::optional<tle::MropeConfig>>(), state[16].cast<std::optional<tle::LoraConfig>>(),
state[17].cast<std::optional<tle::LookaheadDecodingConfig>>(),
state[18].cast<std::optional<tle::KvCacheRetentionConfig>>(), state[19].cast<std::optional<std::string>>(),
state[20].cast<std::optional<tle::LogitsPostProcessor>>(), state[21].cast<std::optional<VecTokens>>(),
state[22].cast<std::optional<IdType>>(), state[23].cast<bool>(), state[24].cast<tle::PriorityType>(),
state[25].cast<tle::RequestType>(), state[26].cast<std::optional<tle::ContextPhaseParams>>(),
state[27].cast<std::optional<tle::Tensor>>(), state[28].cast<std::optional<SizeType32>>(),
state[29].cast<std::optional<tle::Tensor>>(), 1, state[30].cast<std::optional<tle::EagleConfig>>(),
state[31].cast<std::optional<tle::Tensor>>(), state[32].cast<std::optional<tle::GuidedDecodingParams>>(),
std::nullopt, std::nullopt, state[33].cast<std::optional<tle::CacheSaltIDType>>(),
state[34].cast<std::optional<tle::IdType>>());
};
py::class_<tle::Request> request(m, "Request", pybind11::dynamic_attr());
request
.def(py::init<tle::VecTokens, // inputTokenIds
tle::SizeType32, // maxTokens
bool, // streaming
tle::SamplingConfig const&, // samplingConfig
tle::OutputConfig const&, // outputConfig
std::optional<tle::SizeType32> const&, // endId
std::optional<tle::SizeType32> const&, // padId
std::optional<std::vector<SizeType32>>, // positionIds
std::optional<std::list<tle::VecTokens>>, // badWords
std::optional<std::list<tle::VecTokens>>, // stopWords
std::optional<tle::Tensor>, // embeddingBias
std::optional<tle::ExternalDraftTokensConfig>, // externalDraftTokensConfig
std::optional<tle::PromptTuningConfig>, // pTuningConfig
std::optional<tle::MultimodalInput>, // multimodalInput
std::optional<tle::Tensor>, // multimodalEmbedding
std::optional<tle::MropeConfig>, // mRopeConfig
std::optional<tle::LoraConfig>, // loraConfig
std::optional<tle::LookaheadDecodingConfig>, // lookaheadConfig
std::optional<tle::KvCacheRetentionConfig>, // kvCacheRetentionConfig
std::optional<std::string>, // logitsPostProcessorName
std::optional<tle::LogitsPostProcessor>, // logitsPostProcessor
std::optional<tle::VecTokens>, // encoderInputTokenIds
std::optional<tle::IdType>, // clientId
bool, // returnAllGeneratedTokens
tle::PriorityType, // priority
tle::RequestType, // type
std::optional<tle::ContextPhaseParams>, // contextPhaseParams
std::optional<tle::Tensor>, // encoderInputFeatures
std::optional<tle::SizeType32>, // encoderOutputLength
std::optional<tle::Tensor>, // crossAttentionMask
SizeType32, // numReturnSequences
std::optional<tle::EagleConfig>, // eagleConfig
std::optional<tle::Tensor>, // skipCrossAttnBlocks
std::optional<tle::GuidedDecodingParams>, // guidedDecodingParams
std::optional<tle::SizeType32>, // languageAdapterUid
std::optional<tle::MillisecondsType>, // allottedTimeMs
std::optional<tle::CacheSaltIDType>, // cacheSaltID
std::optional<tle::IdType> // disaggRequestId
>(),
// clang-format off
py::arg("input_token_ids"),
py::arg("max_tokens"),
py::kw_only(),
py::arg("streaming") = false,
py::arg_v("sampling_config", tle::SamplingConfig(), "SamplingConfig()"),
py::arg_v("output_config", tle::OutputConfig(), "OutputConfig()"),
py::arg("end_id") = py::none(),
py::arg("pad_id") = py::none(),
py::arg("position_ids") = py::none(),
py::arg("bad_words") = py::none(),
py::arg("stop_words") = py::none(),
py::arg("embedding_bias") = py::none(),
py::arg("external_draft_tokens_config") = py::none(),
py::arg("prompt_tuning_config") = py::none(),
py::arg("multimodal_input") = py::none(),
py::arg("multimodal_embedding") = py::none(),
py::arg("mrope_config") = py::none(),
py::arg("lora_config") = py::none(),
py::arg("lookahead_config") = py::none(),
py::arg("kv_cache_retention_config") = py::none(),
py::arg("logits_post_processor_name") = py::none(),
py::arg("logits_post_processor") = py::none(),
py::arg("encoder_input_token_ids") = py::none(),
py::arg("client_id") = py::none(),
py::arg("return_all_generated_tokens") = false,
py::arg("priority") = tle::Request::kDefaultPriority,
py::arg_v("type", tle::RequestType::REQUEST_TYPE_CONTEXT_AND_GENERATION,
"RequestType.REQUEST_TYPE_CONTEXT_AND_GENERATION"),
py::arg("context_phase_params") = py::none(),
py::arg("encoder_input_features") = py::none(),
py::arg("encoder_output_length") = py::none(),
py::arg("cross_attention_mask") = py::none(),
py::arg("num_return_sequences") = 1,
py::arg("eagle_config") = py::none(),
py::arg("skip_cross_attn_blocks") = py::none(),
py::arg("guided_decoding_params") = py::none(),
py::arg("language_adapter_uid") = py::none(),
py::arg("allotted_time_ms") = py::none(),
py::arg("cache_salt_id") = py::none(),
py::arg("disagg_request_id") = py::none()
) // clang-format on
.def_property_readonly("input_token_ids", &tle::Request::getInputTokenIds)
.def_property_readonly("max_tokens", &tle::Request::getMaxTokens)
.def_property("streaming", &tle::Request::getStreaming, &tle::Request::setStreaming)
.def_property("sampling_config", &tle::Request::getSamplingConfig, &tle::Request::setSamplingConfig)
.def_property("output_config", &tle::Request::getOutputConfig, &tle::Request::setOutputConfig)
.def_property("end_id", &tle::Request::getEndId, &tle::Request::setEndId)
.def_property("pad_id", &tle::Request::getPadId, &tle::Request::setPadId)
.def_property("position_ids", &tle::Request::getPositionIds, &tle::Request::setPositionIds)
.def_property("bad_words", &tle::Request::getBadWords, &tle::Request::setBadWords)
.def_property("stop_words", &tle::Request::getStopWords, &tle::Request::setStopWords)
.def_property("embedding_bias", &tle::Request::getEmbeddingBias, &tle::Request::setEmbeddingBias)
.def_property("external_draft_tokens_config", &tle::Request::getExternalDraftTokensConfig,
&tle::Request::setExternalDraftTokensConfig)
.def_property(
"prompt_tuning_config", &tle::Request::getPromptTuningConfig, &tle::Request::setPromptTuningConfig)
.def_property("multimodal_input", &tle::Request::getMultimodalInput, &tle::Request::setMultimodalInput)
.def_property(
"multimodal_embedding", &tle::Request::getMultimodalEmbedding, &tle::Request::setMultimodalEmbedding)
.def_property("mrope_config", &tle::Request::getMropeConfig, &tle::Request::setMropeConfig)
.def_property("lora_config", &tle::Request::getLoraConfig, &tle::Request::setLoraConfig)
.def_property("lookahead_config", &tle::Request::getLookaheadConfig, &tle::Request::setLookaheadConfig)
.def_property("kv_cache_retention_config", &tle::Request::getKvCacheRetentionConfig,
&tle::Request::setKvCacheRetentionConfig)
.def_property("logits_post_processor_name", &tle::Request::getLogitsPostProcessorName,
&tle::Request::setLogitsPostProcessorName)
.def_property(
"logits_post_processor", &tle::Request::getLogitsPostProcessor, &tle::Request::setLogitsPostProcessor)
.def_property(
"encoder_input_token_ids", &tle::Request::getEncoderInputTokenIds, &tle::Request::setEncoderInputTokenIds)
.def_property("client_id", &tle::Request::getClientId, &tle::Request::setClientId)
.def_property("return_all_generated_tokens", &tle::Request::getReturnAllGeneratedTokens,
&tle::Request::setReturnAllGeneratedTokens)
.def_property("request_type", &tle::Request::getRequestType, &tle::Request::setRequestType)
.def_property(
"encoder_input_features", &tle::Request::getEncoderInputFeatures, &tle::Request::setEncoderInputFeatures)
.def_property(
"cross_attention_mask", &tle::Request::getCrossAttentionMask, &tle::Request::setCrossAttentionMask)
.def_property("eagle_config", &tle::Request::getEagleConfig, &tle::Request::setEagleConfig)
.def_property(
"skip_cross_attn_blocks", &tle::Request::getSkipCrossAttnBlocks, &tle::Request::setSkipCrossAttnBlocks)
.def_property(
"guided_decoding_params", &tle::Request::getGuidedDecodingParams, &tle::Request::setGuidedDecodingParams)
.def_property("allotted_time_ms", &tle::Request::getAllottedTimeMs, &tle::Request::setAllottedTimeMs)
.def_property("cache_salt_id", &tle::Request::getCacheSaltID, &tle::Request::setCacheSaltID)
.def_property(
"context_phase_params", &tle::Request::getContextPhaseParams, &tle::Request::setContextPhaseParams)
.def_property("disagg_request_id", &tle::Request::getDisaggRequestId, &tle::Request::setDisaggRequestId)
.def(py::pickle(requestGetstate, requestSetstate));
request.attr("BATCHED_POST_PROCESSOR_NAME") = tle::Request::kBatchedPostProcessorName;
py::class_<tle::SpeculativeDecodingFastLogitsInfo>(m, "SpeculativeDecodingFastLogitsInfo")
.def(py::init<>())
.def_readwrite("draft_request_id", &tle::SpeculativeDecodingFastLogitsInfo::draftRequestId)
.def_readwrite("draft_participant_id", &tle::SpeculativeDecodingFastLogitsInfo::draftParticipantId)
.def("to_tensor", &tle::SpeculativeDecodingFastLogitsInfo::toTensor);
auto requestPerfMetrics = py::class_<tle::RequestPerfMetrics>(m, "RequestPerfMetrics");
auto timingMetricsGetstate = [](tle::RequestPerfMetrics::TimingMetrics const& self)
{
return py::make_tuple(self.arrivalTime, self.firstScheduledTime, self.firstTokenTime, self.lastTokenTime,
self.kvCacheTransferStart, self.kvCacheTransferEnd, self.kvCacheSize);
};
auto timingMetricsSetstate = [](py::tuple const& state)
{
if (state.size() != 7)
{
throw std::runtime_error("Invalid TimingMetrics state!");
}
return tle::RequestPerfMetrics::TimingMetrics{state[0].cast<tle::RequestPerfMetrics::TimePoint>(),
state[1].cast<tle::RequestPerfMetrics::TimePoint>(), state[2].cast<tle::RequestPerfMetrics::TimePoint>(),
state[3].cast<tle::RequestPerfMetrics::TimePoint>(), state[4].cast<tle::RequestPerfMetrics::TimePoint>(),
state[5].cast<tle::RequestPerfMetrics::TimePoint>(), state[6].cast<size_t>()};
};
py::class_<tle::RequestPerfMetrics::TimingMetrics>(m, "TimingMetrics")
.def(py::init<>())
.def_readwrite("arrival_time", &tle::RequestPerfMetrics::TimingMetrics::arrivalTime)
.def_readwrite("first_scheduled_time", &tle::RequestPerfMetrics::TimingMetrics::firstScheduledTime)
.def_readwrite("first_token_time", &tle::RequestPerfMetrics::TimingMetrics::firstTokenTime)
.def_readwrite("last_token_time", &tle::RequestPerfMetrics::TimingMetrics::lastTokenTime)
.def_readwrite("kv_cache_transfer_start", &tle::RequestPerfMetrics::TimingMetrics::kvCacheTransferStart)
.def_readwrite("kv_cache_transfer_end", &tle::RequestPerfMetrics::TimingMetrics::kvCacheTransferEnd)
.def_readwrite("kv_cache_size", &tle::RequestPerfMetrics::TimingMetrics::kvCacheSize)
.def(py::pickle(timingMetricsGetstate, timingMetricsSetstate));
auto kvCacheMetricsGetstate = [](tle::RequestPerfMetrics::KvCacheMetrics const& self)
{
return py::make_tuple(self.numTotalAllocatedBlocks, self.numNewAllocatedBlocks, self.numReusedBlocks,
self.numMissedBlocks, self.kvCacheHitRate);
};
auto kvCacheMetricsSetstate = [](py::tuple const& state)
{
if (state.size() != 5)
{
throw std::runtime_error("Invalid KvCacheMetrics state!");
}
return tle::RequestPerfMetrics::KvCacheMetrics{state[0].cast<SizeType32>(), state[1].cast<SizeType32>(),
state[2].cast<SizeType32>(), state[3].cast<SizeType32>(), state[4].cast<float>()};
};
py::class_<tle::RequestPerfMetrics::KvCacheMetrics>(m, "KvCacheMetrics")
.def(py::init<>())
.def_readwrite("num_total_allocated_blocks", &tle::RequestPerfMetrics::KvCacheMetrics::numTotalAllocatedBlocks)
.def_readwrite("num_new_allocated_blocks", &tle::RequestPerfMetrics::KvCacheMetrics::numNewAllocatedBlocks)
.def_readwrite("num_reused_blocks", &tle::RequestPerfMetrics::KvCacheMetrics::numReusedBlocks)
.def_readwrite("num_missed_blocks", &tle::RequestPerfMetrics::KvCacheMetrics::numMissedBlocks)
.def_readwrite("kv_cache_hit_rate", &tle::RequestPerfMetrics::KvCacheMetrics::kvCacheHitRate)
.def(py::pickle(kvCacheMetricsGetstate, kvCacheMetricsSetstate));
auto speculativeDecodingMetricsGetstate = [](tle::RequestPerfMetrics::SpeculativeDecodingMetrics const& self)
{ return py::make_tuple(self.acceptanceRate, self.totalAcceptedDraftTokens, self.totalDraftTokens); };
auto speculativeDecodingMetricsSetstate = [](py::tuple const& state)
{
if (state.size() != 3)
{
throw std::runtime_error("Invalid SpeculativeDecodingMetrics state!");
}
return tle::RequestPerfMetrics::SpeculativeDecodingMetrics{
state[0].cast<float>(), state[1].cast<SizeType32>(), state[2].cast<SizeType32>()};
};
py::class_<tle::RequestPerfMetrics::SpeculativeDecodingMetrics>(m, "SpeculativeDecodingMetrics")
.def(py::init<>())
.def_readwrite("acceptance_rate", &tle::RequestPerfMetrics::SpeculativeDecodingMetrics::acceptanceRate)
.def_readwrite("total_accepted_draft_tokens",
&tle::RequestPerfMetrics::SpeculativeDecodingMetrics::totalAcceptedDraftTokens)
.def_readwrite("total_draft_tokens", &tle::RequestPerfMetrics::SpeculativeDecodingMetrics::totalDraftTokens)
.def(py::pickle(speculativeDecodingMetricsGetstate, speculativeDecodingMetricsSetstate));
auto requestPerfMetricsGetstate = [](tle::RequestPerfMetrics const& self)
{
return py::make_tuple(self.timingMetrics, self.kvCacheMetrics, self.speculativeDecoding, self.firstIter,
self.lastIter, self.iter);
};
auto requestPerfMetricsSetstate = [](py::tuple const& state)
{
if (state.size() != 6)
{
throw std::runtime_error("Invalid RequestPerfMetrics state!");
}
return tle::RequestPerfMetrics{state[0].cast<tle::RequestPerfMetrics::TimingMetrics>(),
state[1].cast<tle::RequestPerfMetrics::KvCacheMetrics>(),
state[2].cast<tle::RequestPerfMetrics::SpeculativeDecodingMetrics>(),
state[3].cast<std::optional<tle::IterationType>>(), state[4].cast<std::optional<tle::IterationType>>(),
state[5].cast<std::optional<tle::IterationType>>()};
};
// There's a circular dependency between the declaration of the TimingMetrics and RequestPerfMetrics bindings.
// Defer definition of the RequestPerfMetrics bindings until the TimingMetrics have been defined.
requestPerfMetrics.def(py::init<>())
.def_readwrite("timing_metrics", &tle::RequestPerfMetrics::timingMetrics)
.def_readwrite("kv_cache_metrics", &tle::RequestPerfMetrics::kvCacheMetrics)
.def_readwrite("speculative_decoding", &tle::RequestPerfMetrics::speculativeDecoding)
.def_readwrite("first_iter", &tle::RequestPerfMetrics::firstIter)
.def_readwrite("last_iter", &tle::RequestPerfMetrics::lastIter)
.def_readwrite("iter", &tle::RequestPerfMetrics::iter)
.def(py::pickle(requestPerfMetricsGetstate, requestPerfMetricsSetstate));
py::class_<tle::AdditionalOutput>(m, "AdditionalOutput")
.def(py::init([](std::string const& name, tle::Tensor const& output)
{ return std::make_unique<tle::AdditionalOutput>(name, output); }))
.def_readwrite("name", &tle::AdditionalOutput::name)
.def_readwrite("output", &tle::AdditionalOutput::output);
auto resultSetstate = [](py::tuple const& state)
{
if (state.size() != 14)
{
throw std::runtime_error("Invalid Request state!");
}
tle::Result result;
result.isFinal = state[0].cast<bool>();
result.outputTokenIds = state[1].cast<std::vector<VecTokens>>();
result.cumLogProbs = state[2].cast<std::optional<std::vector<float>>>();
result.logProbs = state[3].cast<std::optional<std::vector<std::vector<float>>>>();
result.contextLogits = state[4].cast<std::optional<Tensor>>();
result.generationLogits = state[5].cast<std::optional<Tensor>>();
result.encoderOutput = state[6].cast<std::optional<Tensor>>();
result.finishReasons = state[7].cast<std::vector<tle::FinishReason>>();
result.sequenceIndex = state[8].cast<SizeType32>();
result.isSequenceFinal = state[9].cast<bool>();
result.decodingIter = state[10].cast<SizeType32>();
result.avgDecodedTokensPerIter = state[11].cast<float>();
result.contextPhaseParams = state[12].cast<std::optional<tle::ContextPhaseParams>>();
result.requestPerfMetrics = state[13].cast<std::optional<tle::RequestPerfMetrics>>();
return std::make_unique<tle::Result>(result);
};
auto resultGetstate = [](tle::Result const& self)
{
return py::make_tuple(self.isFinal, self.outputTokenIds, self.cumLogProbs, self.logProbs, self.contextLogits,
self.generationLogits, self.encoderOutput, self.finishReasons, self.sequenceIndex, self.isSequenceFinal,
self.decodingIter, self.avgDecodedTokensPerIter, self.contextPhaseParams, self.requestPerfMetrics);
};
py::class_<tle::Result>(m, "Result")
.def(py::init<>())
.def_readwrite("is_final", &tle::Result::isFinal)
.def_readwrite("output_token_ids", &tle::Result::outputTokenIds)
.def_readwrite("cum_log_probs", &tle::Result::cumLogProbs)
.def_readwrite("log_probs", &tle::Result::logProbs)
.def_readwrite("context_logits", &tle::Result::contextLogits)
.def_readwrite("generation_logits", &tle::Result::generationLogits)
.def_readwrite("spec_dec_fast_logits_info", &tle::Result::specDecFastLogitsInfo)
.def_readwrite("encoder_output", &tle::Result::encoderOutput)
.def_readwrite("finish_reasons", &tle::Result::finishReasons)
.def_readwrite("sequence_index", &tle::Result::sequenceIndex)
.def_readwrite("is_sequence_final", &tle::Result::isSequenceFinal)
.def_readwrite("decoding_iter", &tle::Result::decodingIter)
.def_readwrite("avg_decoded_tokens_per_iter", &tle::Result::avgDecodedTokensPerIter)
.def_readwrite("context_phase_params", &tle::Result::contextPhaseParams)
.def_readwrite("request_perf_metrics", &tle::Result::requestPerfMetrics)
.def_readwrite("additional_outputs", &tle::Result::additionalOutputs)
.def(py::pickle(resultGetstate, resultSetstate));
m.def("deserialize_result",
[](std::string& x)
{
std::istringstream is(x);
return tle::serialize_utils::deserialize<tle::Result>(is);
});
auto responseGetstate = [](tle::Response const& self)
{ return py::make_tuple(self.getRequestId(), self.getResult(), self.getClientId()); };
auto responseSetstate = [](py::tuple const& state)
{
if (state.size() != 3)
{
throw std::runtime_error("Invalid Request state!");
}
return std::make_unique<tle::Response>(
state[0].cast<IdType>(), state[1].cast<tle::Result>(), state[2].cast<IdType>());
};
py::class_<tle::Response>(m, "Response")
.def(py::init<IdType, std::string, std::optional<IdType>>(), py::arg("request_id"), py::arg("error_msg"),
py::arg("client_id") = std::nullopt)
.def(py::init<IdType, tle::Result, std::optional<IdType>>(), py::arg("request_id"), py::arg("result"),
py::arg("client_id") = std::nullopt)
.def_property_readonly("request_id", &tle::Response::getRequestId)
.def_property_readonly("client_id", &tle::Response::getClientId)
.def("has_error", &tle::Response::hasError)
.def_property_readonly("error_msg", &tle::Response::getErrorMsg)
.def_property_readonly("result", &tle::Response::getResult)
.def("clear_context_logits",
[](tle::Response& self)
{
if (!self.hasError())
{
auto& result = const_cast<tle::Result&>(self.getResult());
result.contextLogits.reset();
}
})
.def("clear_generation_logits",
[](tle::Response& self)
{
if (!self.hasError())
{
auto& result = const_cast<tle::Result&>(self.getResult());
result.generationLogits.reset();
}
})
.def(py::pickle(responseGetstate, responseSetstate));
}
} // namespace tensorrt_llm::pybind::executor
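// --- Illustrative aside (not part of the original file) ---
// A minimal, self-contained sketch of the py::pickle(getstate, setstate) pattern used throughout
// the bindings above: __getstate__ packs the C++ members into a py::tuple, __setstate__ validates
// the tuple size and reconstructs the object. The Point type and module name are hypothetical.
#include <pybind11/pybind11.h>
#include <stdexcept>

namespace py = pybind11;

struct Point
{
    int x{0};
    int y{0};
};

PYBIND11_MODULE(pickle_sketch, m)
{
    py::class_<Point>(m, "Point")
        .def(py::init<>())
        .def_readwrite("x", &Point::x)
        .def_readwrite("y", &Point::y)
        .def(py::pickle(
            // __getstate__: capture every member needed to rebuild the object.
            [](Point const& self) { return py::make_tuple(self.x, self.y); },
            // __setstate__: check the tuple shape, then reconstruct.
            [](py::tuple const& state)
            {
                if (state.size() != 2)
                {
                    throw std::runtime_error("Invalid Point state!");
                }
                Point p;
                p.x = state[0].cast<int>();
                p.y = state[1].cast<int>();
                return p;
            }));
}
// --- End of illustrative aside ---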

View File

@ -1,29 +0,0 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/pybind/common/customCasters.h"
#include <pybind11/pybind11.h>
namespace tensorrt_llm::pybind::executor
{
// Register bindings for executor API.
void initRequestBindings(pybind11::module_& m);
} // namespace tensorrt_llm::pybind::executor

View File

@ -1,40 +0,0 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "bindings.h"
#include "tensorrt_llm/common/bindingUtils.h"
#include "tensorrt_llm/runtime/utils/pgUtils.h"
namespace tensorrt_llm::pybind::process_group
{
void initBindings(py::module_& m)
{
m.def("init_pg",
[](py::object world_pg_obj, py::object local_pg_obj, std::string const& pybind11_abi)
{
using Pg = c10d::ProcessGroup;
using E = py::error_already_set;
pg_utils::init_pg(common::get_intrusive_ptr<Pg, E>(world_pg_obj.ptr(), pybind11_abi),
common::get_intrusive_ptr<Pg, E>(local_pg_obj.ptr(), pybind11_abi));
});
}
} // namespace tensorrt_llm::pybind::process_group

View File

@ -1,27 +0,0 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
namespace tensorrt_llm::pybind::process_group
{
void initBindings(py::module_& m);
} // namespace tensorrt_llm::pybind::process_group

View File

@ -1,510 +0,0 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "bindings.h"
#include "hostfunc.h"
#include "moeBindings.h"
#include "tensorrt_llm/batch_manager/decoderBuffers.h"
#include "tensorrt_llm/kernels/communicationKernels/allReduceWorkspace.h"
#include "tensorrt_llm/kernels/communicationKernels/customLowPrecisionAllReduceKernels.h"
#include "tensorrt_llm/kernels/customAllReduceKernels.h"
#include "tensorrt_llm/kernels/delayStream.h"
#include "tensorrt_llm/runtime/cudaEvent.h"
#include "tensorrt_llm/runtime/cudaStream.h"
#include "tensorrt_llm/runtime/decoderState.h"
#include "tensorrt_llm/runtime/decodingInput.h"
#include "tensorrt_llm/runtime/decodingOutput.h"
#include "tensorrt_llm/runtime/gptDecoder.h"
#include "tensorrt_llm/runtime/gptDecoderBatched.h"
#include "tensorrt_llm/runtime/iBuffer.h"
#include "tensorrt_llm/runtime/iGptDecoderBatched.h"
#include "tensorrt_llm/runtime/iTensor.h"
#include "tensorrt_llm/runtime/ipcUtils.h"
#include "tensorrt_llm/runtime/lookaheadBuffers.h"
#include "tensorrt_llm/runtime/loraCache.h"
#include "tensorrt_llm/runtime/mcastGPUBuffer.h"
#include "tensorrt_llm/runtime/speculativeDecodingMode.h"
#include "tensorrt_llm/runtime/tllmRuntime.h"
#include "tensorrt_llm/runtime/torchView.h"
#include "tensorrt_llm/runtime/virtualMemory.h"
#include <ATen/ATen.h>
#include <c10/cuda/CUDAStream.h>
#include <pybind11/stl.h>
#include <pybind11/stl_bind.h>
#include <torch/extension.h>
namespace tr = tensorrt_llm::runtime;
namespace te = tensorrt_llm::executor;
namespace tb = tensorrt_llm::batch_manager;
class PyITensor : public tensorrt_llm::runtime::ITensor
{
public:
/* Inherit the constructors */
using ITensor::ITensor;
[[nodiscard]] void* data() override
{
PYBIND11_OVERRIDE_PURE(void*, /* Return type */
ITensor, /* Parent class */
data /* Name of function in C++ (must match Python name) */
/* Argument(s) */
);
}
[[nodiscard]] void const* data() const override
{
PYBIND11_OVERRIDE_PURE(void const*, /* Return type */
ITensor, /* Parent class */
data /* Name of function in C++ (must match Python name) */
/* Argument(s) */
);
}
[[nodiscard]] std::size_t getSize() const override
{
PYBIND11_OVERRIDE_PURE(std::size_t, /* Return type */
ITensor, /* Parent class */
getSize /* Name of function in C++ (must match Python name) */
/* Argument(s) */
);
}
[[nodiscard]] std::size_t getCapacity() const override
{
PYBIND11_OVERRIDE_PURE(std::size_t, /* Return type */
ITensor, /* Parent class */
getCapacity /* Name of function in C++ (must match Python name) */
/* Argument(s) */
);
}
[[nodiscard]] DataType getDataType() const override
{
PYBIND11_OVERRIDE_PURE(DataType, /* Return type */
ITensor, /* Parent class */
getDataType /* Name of function in C++ (must match Python name) */
/* Argument(s) */
);
}
[[nodiscard]] tr::MemoryType getMemoryType() const override
{
PYBIND11_OVERRIDE_PURE(tr::MemoryType, /* Return type */
ITensor, /* Parent class */
getMemoryType /* Name of function in C++ (must match Python name) */
/* Argument(s) */
);
}
[[nodiscard]] char const* getMemoryTypeName() const override
{
PYBIND11_OVERRIDE_PURE(char const*, /* Return type */
ITensor, /* Parent class */
getMemoryTypeName /* Name of function in C++ (must match Python name) */
/* Argument(s) */
);
}
virtual void resize(std::size_t newSize) override
{
PYBIND11_OVERRIDE_PURE(void, /* Return type */
ITensor, /* Parent class */
resize /* Name of function in C++ (must match Python name) */
/* Argument(s) */
);
}
void release() override
{
PYBIND11_OVERRIDE_PURE(void, /* Return type */
ITensor, /* Parent class */
release /* Name of function in C++ (must match Python name) */
/* Argument(s) */
);
}
[[nodiscard]] Shape const& getShape() const override
{
PYBIND11_OVERRIDE_PURE(Shape const&, /* Return type */
ITensor, /* Parent class */
getShape /* Name of function in C++ (must match Python name) */
/* Argument(s) */
);
}
void reshape(Shape const& dims) override
{
PYBIND11_OVERRIDE_PURE(void, /* Return type */
ITensor, /* Parent class */
reshape, /* Name of function in C++ (must match Python name) */
dims /* Argument(s) */
);
}
};
class PyIGptDecoder : public tr::IGptDecoder
{
public:
using tr::IGptDecoder::IGptDecoder; // Inherit constructors
void setup(tr::SamplingConfig const& samplingConfig, size_t batchSize,
tr::DecodingInput::TensorConstPtr const& batchSlots,
std::optional<tr::DecodingOutput> const& output = std::nullopt,
std::optional<nvinfer1::DataType> explicitDraftTokensDType = std::nullopt,
std::optional<std::vector<tr::ITensor::SharedConstPtr>> const& lookaheadPrompt = std::nullopt,
std::optional<std::vector<te::LookaheadDecodingConfig>> const& lookaheadAlgoConfigs = std::nullopt) override
{
PYBIND11_OVERRIDE_PURE(void, IGptDecoder, setup, samplingConfig, batchSize, batchSlots, output,
explicitDraftTokensDType, lookaheadPrompt, lookaheadAlgoConfigs);
}
void forwardAsync(tr::DecodingOutput& output, tr::DecodingInput const& input) override
{
PYBIND11_OVERRIDE_PURE(void, IGptDecoder, forwardAsync, output, input);
}
void forwardSync(tr::DecodingOutput& output, tr::DecodingInput const& input) override
{
PYBIND11_OVERRIDE_PURE(void, IGptDecoder, forwardSync, output, input);
}
tr::SamplingConfig const& getSamplingConfig() override
{
PYBIND11_OVERRIDE_PURE(tr::SamplingConfig const&, IGptDecoder, getSamplingConfig);
}
void disableLookahead(std::optional<tr::SamplingConfig> const& samplingConfig, tr::SizeType32 batchSize,
tr::DecodingInput::TensorConstPtr batchSlots) override
{
PYBIND11_OVERRIDE_PURE(void, IGptDecoder, disableLookahead, samplingConfig, batchSize, batchSlots);
}
};
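// --- Illustrative aside (not part of the original file) ---
// A minimal sketch of the PYBIND11_OVERRIDE_PURE trampoline pattern that PyITensor and
// PyIGptDecoder follow: a helper class derives from the abstract C++ interface and forwards each
// pure-virtual call to a Python override, so the interface can be subclassed from Python. The
// Animal/PyAnimal names and module name are hypothetical.
#include <pybind11/pybind11.h>
#include <string>

namespace py = pybind11;

class Animal
{
public:
    virtual ~Animal() = default;
    virtual std::string speak() const = 0;
};

class PyAnimal : public Animal
{
public:
    using Animal::Animal; // inherit constructors

    std::string speak() const override
    {
        PYBIND11_OVERRIDE_PURE(std::string, /* return type */
            Animal,                         /* parent class */
            speak                           /* name of the C++ (and Python) method */
        );
    }
};

PYBIND11_MODULE(trampoline_sketch, m)
{
    py::class_<Animal, PyAnimal>(m, "Animal")
        .def(py::init<>())
        .def("speak", &Animal::speak);
}
// --- End of illustrative aside ---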
namespace tensorrt_llm::pybind::runtime
{
void initBindings(pybind11::module_& m)
{
py::classh<tr::ITensor, PyITensor>(m, "ITensor").def(py::init());
py::class_<tr::LoraCache::TaskLayerModuleConfig>(m, "TaskLayerModuleConfig")
.def(py::init<>())
.def_readwrite("page_id", &tr::LoraCache::TaskLayerModuleConfig::pageId)
.def_readwrite("slot_idx", &tr::LoraCache::TaskLayerModuleConfig::slotIdx)
.def_readwrite("in_size", &tr::LoraCache::TaskLayerModuleConfig::inSize)
.def_readwrite("out_size", &tr::LoraCache::TaskLayerModuleConfig::outSize)
.def_readwrite("module_id", &tr::LoraCache::TaskLayerModuleConfig::moduleId)
.def_readwrite("layer_id", &tr::LoraCache::TaskLayerModuleConfig::layerId)
.def_readwrite("adapter_size", &tr::LoraCache::TaskLayerModuleConfig::adapterSize)
.def_readwrite("num_slots", &tr::LoraCache::TaskLayerModuleConfig::numSlots)
.def_readwrite("weights_in_pointer", &tr::LoraCache::TaskLayerModuleConfig::weightsInPointer)
.def_readwrite("weights_out_pointer", &tr::LoraCache::TaskLayerModuleConfig::weightsOutPointer)
.def_readwrite("scaling_vec_pointer", &tr::LoraCache::TaskLayerModuleConfig::scalingVecPointer)
.def(py::self == py::self);
py::class_<tr::CudaVirtualMemoryManager>(m, "CudaVirtualMemoryManager")
.def("release_with_tag", &tr::CudaVirtualMemoryManager::releaseWithTag, py::arg("tag"),
py::call_guard<py::gil_scoped_release>())
.def("materialize_with_tag", &tr::CudaVirtualMemoryManager::materializeWithTag, py::arg("tag"),
py::call_guard<py::gil_scoped_release>());
py::classh<tr::TllmRuntime>(m, "TllmRuntime")
.def(py::init(
[](std::filesystem::path engine_path, float gpu_weights_percent = 1.0f, bool use_shape_inference = true)
{
// Using default logger by passing nullptr
return new tr::TllmRuntime(
tr::RawEngine(engine_path), nullptr, gpu_weights_percent, use_shape_inference);
}))
.def(py::init(
[](py::buffer engine_buffer, float gpu_weights_percent = 1.0f, bool use_shape_inference = true)
{
py::buffer_info info = engine_buffer.request();
if (info.ndim != 1)
throw std::runtime_error("Expected 1-D array for engine buffer");
return new tr::TllmRuntime(
tr::RawEngine(info.ptr, info.shape[0]), nullptr, gpu_weights_percent, use_shape_inference);
}))
.def_property_readonly("num_contexts", &tr::TllmRuntime::getNbContexts)
.def_property_readonly("num_profiles", &tr::TllmRuntime::getNbProfiles)
.def("get_opt_profile_id", &tr::TllmRuntime::getOptProfileId, py::arg("num_tokens"), py::arg("split_points"),
py::call_guard<py::gil_scoped_release>())
.def("clear_contexts", &tr::TllmRuntime::clearContexts, py::call_guard<py::gil_scoped_release>())
.def("execute_context", &tr::TllmRuntime::executeContext, py::arg("context_id"),
py::call_guard<py::gil_scoped_release>())
.def_property_readonly("stream_ptr", &tr::TllmRuntime::getStreamPtr)
.def_property_readonly("buffer_manager",
static_cast<tr::BufferManager& (tr::TllmRuntime::*) ()>(&tr::TllmRuntime::getBufferManager))
.def("set_layer_profiler", &tr::TllmRuntime::setLayerProfiler, py::call_guard<py::gil_scoped_release>())
.def("has_layer_profiler", &tr::TllmRuntime::hasLayerProfiler, py::arg("context_id"),
py::call_guard<py::gil_scoped_release>())
.def_property_readonly("layer_profiler_info", &tr::TllmRuntime::getLayerProfileInfo)
.def("report_to_profiler", &tr::TllmRuntime::reportToProfiler, py::arg("context_id"),
py::call_guard<py::gil_scoped_release>())
.def_property_readonly("logits_dtype_from_engine",
[](tr::TllmRuntime& self) { return self.getEngine().getTensorDataType("logits"); });
py::class_<tr::LookaheadDecodingBuffers>(m, "LookaheadDecodingBuffers")
.def(py::init<tr::SizeType32, tr::SizeType32, tr::BufferManager const&>(), py::arg("max_num_sequences"),
py::arg("max_tokens_per_step"), py::arg("buffer_manager"), py::call_guard<py::gil_scoped_release>())
.def_readwrite("generation_lengths", &tr::LookaheadDecodingBuffers::generationLengths)
.def_readwrite("position_offsets", &tr::LookaheadDecodingBuffers::positionOffsets)
.def_readwrite("packed_masks", &tr::LookaheadDecodingBuffers::packedMasks)
.def_readwrite("position_ids", &tr::LookaheadDecodingBuffers::positionIds);
py::class_<tr::ExplicitDraftTokensBuffers::Inputs>(m, "ExplicitDraftTokensBuffersInputs")
.def("create", &tr::ExplicitDraftTokensBuffers::Inputs::create, py::arg("max_num_sequences"),
py::arg("runtime"), py::arg("model_config"), py::arg("world_config"),
py::call_guard<py::gil_scoped_release>())
.def_readwrite("temperatures", &tr::ExplicitDraftTokensBuffers::Inputs::temperatures)
.def_readwrite("position_ids_base", &tr::ExplicitDraftTokensBuffers::Inputs::positionIdsBase)
.def_readwrite("generation_lengths", &tr::ExplicitDraftTokensBuffers::Inputs::generationLengths)
.def_readwrite("random_data_sample", &tr::ExplicitDraftTokensBuffers::Inputs::randomDataSample)
.def_readwrite("random_data_validation", &tr::ExplicitDraftTokensBuffers::Inputs::randomDataValidation)
.def_readwrite("draft_tokens", &tr::ExplicitDraftTokensBuffers::Inputs::draftTokens)
.def_readwrite("draft_indices", &tr::ExplicitDraftTokensBuffers::Inputs::draftIndices)
.def_readwrite("draft_probs", &tr::ExplicitDraftTokensBuffers::Inputs::draftProbs)
.def_readwrite("packed_masks", &tr::ExplicitDraftTokensBuffers::Inputs::packedMasks)
.def_readwrite("position_ids", &tr::ExplicitDraftTokensBuffers::Inputs::positionIds)
.def_readwrite("max_gen_length_host", &tr::ExplicitDraftTokensBuffers::Inputs::maxGenLengthHost)
.def_readwrite("generation_lengths_host", &tr::ExplicitDraftTokensBuffers::Inputs::generationLengthsHost);
py::class_<tr::DecodingInput>(m, "DecodingInput");
py::class_<tr::DecodingOutput>(m, "DecodingOutput");
py::class_<tr::CudaEvent>(m, "CudaEvent")
.def(py::init<unsigned int>(), py::arg("flags") = cudaEventDisableTiming,
py::call_guard<py::gil_scoped_release>())
.def("synchronize", &tr::CudaEvent::synchronize, py::call_guard<py::gil_scoped_release>());
py::class_<tr::IGptDecoder, PyIGptDecoder>(m, "IGptDecoder")
.def(
"setup",
[](tr::IGptDecoder& self, tr::SamplingConfig const& samplingConfig, size_t batchSize,
at::Tensor const& batchSlots, std::optional<tr::DecodingOutput> const& output = std::nullopt,
std::optional<nvinfer1::DataType> explicitDraftTokensDType = std::nullopt,
std::optional<std::vector<tr::ITensor::SharedConstPtr>> const& lookaheadPrompt = std::nullopt,
std::optional<std::vector<te::LookaheadDecodingConfig>> const& lookaheadAlgoConfigs = std::nullopt)
{
auto tensorPtrBatchSlots = tr::TorchView::of(batchSlots);
self.setup(samplingConfig, batchSize, std::move(tensorPtrBatchSlots), output, explicitDraftTokensDType,
lookaheadPrompt, lookaheadAlgoConfigs);
},
py::arg("sampling_config"), py::arg("batch_size"), py::arg("batch_slots"), py::arg("output") = std::nullopt,
py::arg("explicit_draft_tokens_d_type") = std::nullopt, py::arg("lookahead_prompt") = std::nullopt,
py::arg("lookahead_algo_configs") = std::nullopt, py::call_guard<py::gil_scoped_release>());
py::class_<tr::decoder::DecoderState>(m, "DecoderState")
.def(py::init<>(), py::call_guard<py::gil_scoped_release>())
.def("setup", &tr::decoder::DecoderState::setup, py::arg("max_num_sequences"), py::arg("max_beam_width"),
py::arg("max_attention_window"), py::arg("sink_token_length"), py::arg("max_sequence_length"),
py::arg("dtype"), py::arg("model_config"), py::arg("world_config"), py::arg("buffer_manager"),
py::call_guard<py::gil_scoped_release>())
.def("setup_cache_indirection", &tr::decoder::DecoderState::setupCacheIndirection, py::arg("max_num_sequences"),
py::arg("max_beam_width"), py::arg("max_attention_window"), py::arg("buffer_manager"),
py::call_guard<py::gil_scoped_release>())
.def("setup_speculative_decoding", &tr::decoder::DecoderState::setupSpeculativeDecoding,
py::arg("speculative_decoding_mode"), py::arg("max_tokens_per_engine_step"), py::arg("dtype"),
py::arg("model_config"), py::arg("world_config"), py::arg("buffer_manager"),
py::call_guard<py::gil_scoped_release>())
.def_property_readonly("joint_decoding_input", &tr::decoder::DecoderState::getJointDecodingInput)
.def_property_readonly("joint_decoding_output", &tr::decoder::DecoderState::getJointDecodingOutput)
.def_property_readonly("cache_indirection_input", &tr::decoder::DecoderState::getCacheIndirectionInput)
.def_property_readonly("cache_indirection_output", &tr::decoder::DecoderState::getCacheIndirectionOutput)
.def_property_readonly(
"sequence_lengths", py::overload_cast<>(&tr::decoder::DecoderState::getSequenceLengths, py::const_))
.def("get_sequence_lengths",
py::overload_cast<tr::SizeType32>(&tr::decoder::DecoderState::getSequenceLengths, py::const_),
py::arg("batch_idx"), py::call_guard<py::gil_scoped_release>())
.def_property_readonly("all_new_tokens", &tr::decoder::DecoderState::getAllNewTokens)
.def_property_readonly("finished_sum", &tr::decoder::DecoderState::getFinishedSum)
.def_property_readonly("finish_reasons", &tr::decoder::DecoderState::getFinishReasons)
.def_property_readonly("ids", py::overload_cast<>(&tr::decoder::DecoderState::getIds, py::const_))
.def("get_ids", py::overload_cast<tr::SizeType32>(&tr::decoder::DecoderState::getIds, py::const_),
py::arg("batch_idx"), py::call_guard<py::gil_scoped_release>())
.def_property_readonly(
"gathered_ids", py::overload_cast<>(&tr::decoder::DecoderState::getGatheredIds, py::const_))
.def("get_gathered_ids",
py::overload_cast<tr::SizeType32>(&tr::decoder::DecoderState::getGatheredIds, py::const_),
py::arg("batch_idx"), py::call_guard<py::gil_scoped_release>())
.def_property_readonly("parent_ids", &tr::decoder::DecoderState::getParentIds)
.def_property_readonly(
"cum_log_probs", py::overload_cast<>(&tr::decoder::DecoderState::getCumLogProbs, py::const_))
.def("get_cum_log_probs",
py::overload_cast<tr::SizeType32>(&tr::decoder::DecoderState::getCumLogProbs, py::const_),
py::arg("batch_idx"), py::call_guard<py::gil_scoped_release>())
.def_property_readonly("log_probs", py::overload_cast<>(&tr::decoder::DecoderState::getLogProbs, py::const_))
.def("get_log_probs", py::overload_cast<tr::SizeType32>(&tr::decoder::DecoderState::getLogProbs, py::const_),
py::arg("batch_idx"), py::call_guard<py::gil_scoped_release>())
.def_property_readonly("next_draft_tokens", &tr::decoder::DecoderState::getNextDraftTokens)
.def_property_readonly("prev_draft_tokens_lengths", &tr::decoder::DecoderState::getPrevDraftTokensLengths)
.def_property_readonly("next_draft_tokens_lengths", &tr::decoder::DecoderState::getNextDraftTokensLengths)
.def_property_readonly("accepted_lengths_cum_sum", &tr::decoder::DecoderState::getAcceptedLengthsCumSum)
.def_property_readonly("accepted_packed_paths", &tr::decoder::DecoderState::getAcceptedPackedPaths)
.def_property_readonly("max_beam_width", &tr::decoder::DecoderState::getMaxBeamWidth)
.def_property_readonly("max_sequence_length", &tr::decoder::DecoderState::getMaxSequenceLength)
.def_property_readonly("max_decoding_decoder_tokens", &tr::decoder::DecoderState::getMaxDecodingDecoderTokens)
.def_property_readonly("max_decoding_engine_tokens", &tr::decoder::DecoderState::getMaxDecodingEngineTokens)
.def_property_readonly("num_decoding_engine_tokens",
py::overload_cast<>(&tr::decoder::DecoderState::getNumDecodingEngineTokens, py::const_))
.def("get_num_decoding_engine_tokens",
py::overload_cast<tr::SizeType32>(&tr::decoder::DecoderState::getNumDecodingEngineTokens, py::const_),
py::arg("batch_idx"), py::call_guard<py::gil_scoped_release>())
.def("set_num_decoding_engine_tokens", &tr::decoder::DecoderState::setNumDecodingEngineTokens,
py::arg("batch_idx"), py::arg("num_tokens"), py::call_guard<py::gil_scoped_release>())
.def_property_readonly("speculative_decoding_mode", &tr::decoder::DecoderState::getSpeculativeDecodingMode)
.def_property("generation_steps", &tr::decoder::DecoderState::getGenerationSteps,
&tr::decoder::DecoderState::setGenerationSteps);
py::class_<tr::GptDecoderBatched>(m, "GptDecoderBatched")
.def(py::init<tr::GptDecoderBatched::CudaStreamPtr>(), py::arg("stream"),
py::call_guard<py::gil_scoped_release>())
.def("setup", &tr::GptDecoderBatched::setup, py::arg("mode"), py::arg("max_num_sequences"),
py::arg("max_beam_width"), py::arg("dtype"), py::arg("model_config"), py::arg("world_config"),
py::call_guard<py::gil_scoped_release>())
.def("forward_async", &tr::GptDecoderBatched::forwardAsync, py::arg("decoder_state"), py::arg("input"),
py::call_guard<py::gil_scoped_release>())
.def("underlying_decoder", &tr::GptDecoderBatched::getUnderlyingDecoder, py::return_value_policy::reference)
.def("finalize", &tr::GptDecoderBatched::finalize, py::arg("decoder_state"), py::arg("batch_idx"),
py::arg("sampling_config"), py::arg("streaming"), py::call_guard<py::gil_scoped_release>())
.def_property_readonly(
"decoder_stream",
[](tr::GptDecoderBatched& self) -> tr::CudaStream const& { return *self.getDecoderStream(); },
py::return_value_policy::reference);
m.def(
"lamport_initialize_all",
[](intptr_t buffer_0, intptr_t buffer_1, intptr_t buffer_2, size_t size)
{
tr::lamportInitializeAll(reinterpret_cast<void*>(buffer_0), reinterpret_cast<void*>(buffer_1),
reinterpret_cast<void*>(buffer_2), size);
},
"Lamport initialize all buffers", py::call_guard<py::gil_scoped_release>());
m.def(
"lamport_initialize",
[](intptr_t buffer, size_t size)
{ tensorrt_llm::kernels::ar_fusion::lamport_initialize(reinterpret_cast<void*>(buffer), size, 0); },
"Lmaport initialize buffer", py::call_guard<py::gil_scoped_release>());
m.def(
"delay_kernel",
[](int64_t delay_micro_secs, py::object py_stream)
{
// Get the raw stream handle from PyTorch stream object
auto stream_ptr = py_stream.attr("cuda_stream").cast<int64_t>();
cudaStream_t stream = reinterpret_cast<cudaStream_t>(stream_ptr);
py::gil_scoped_release release;
tensorrt_llm::kernels::invokeDelayStreamKernel(delay_micro_secs, stream);
},
"Delay kernel launch on the default stream");
m.def(
"max_workspace_size_lowprecision",
[](int32_t tp_size) { return tensorrt_llm::kernels::max_workspace_size_lowprecision(tp_size); },
"Calculate the maximum workspace size needed for low precision all-reduce operations",
py::call_guard<py::gil_scoped_release>());
py::enum_<tr::CudaVirtualMemoryAllocator::RestoreMode>(m, "CudaVirtualMemoryAllocatorRestoreMode")
.value("NONE", tr::CudaVirtualMemoryAllocator::RestoreMode::NONE)
.value("CPU", tr::CudaVirtualMemoryAllocator::RestoreMode::CPU)
.value("PINNED", tr::CudaVirtualMemoryAllocator::RestoreMode::PINNED)
.value("MEMSET", tr::CudaVirtualMemoryAllocator::RestoreMode::MEMSET);
m.def("get_virtual_memory_manager", &tr::getVirtualMemoryManager, "Get the virtual memory manager",
py::return_value_policy::reference);
m.def(
"set_virtual_memory_allocator",
[](std::string const& tag, tr::CudaVirtualMemoryAllocator::RestoreMode mode, uintptr_t stream)
{
static_assert(sizeof(uintptr_t) == sizeof(cudaStream_t));
tr::setVirtualMemoryAllocator(tag, mode,
std::make_shared<tr::CudaStream>(
reinterpret_cast<cudaStream_t>(stream), tensorrt_llm::common::getDevice(), false));
},
"Set the virtual memory allocator and start allocating virtual memory for CUDA allocations",
py::call_guard<py::gil_scoped_release>());
m.def("clear_virtual_memory_allocator", &tr::clearVirtualMemoryAllocator,
"Reset the current virtual memory allocator and stop allocating virtual memory for CUDA allocations",
py::call_guard<py::gil_scoped_release>());
py::class_<tensorrt_llm::runtime::McastGPUBuffer>(m, "McastGPUBuffer")
.def(py::init<size_t, uint32_t, uint32_t, uint32_t, bool, int64_t>(), py::arg("buf_size"),
py::arg("group_size"), py::arg("group_rank"), py::arg("device_idx"), py::arg("mn_nvlink"),
py::arg("mpi_comm_fortran_handle"), py::call_guard<py::gil_scoped_release>())
.def("get_uc_buffer", &tensorrt_llm::runtime::McastGPUBuffer::getUCBuffer,
py::call_guard<py::gil_scoped_release>())
.def("get_mc_buffer", &tensorrt_llm::runtime::McastGPUBuffer::getMCBuffer,
py::call_guard<py::gil_scoped_release>());
py::enum_<tensorrt_llm::kernels::AllReduceFusionOp>(m, "AllReduceFusionOp")
.value("NONE", tensorrt_llm::kernels::AllReduceFusionOp::NONE)
.value("RESIDUAL_RMS_NORM", tensorrt_llm::kernels::AllReduceFusionOp::RESIDUAL_RMS_NORM)
.value("LAST_PROCESS_FOR_UB", tensorrt_llm::kernels::AllReduceFusionOp::LAST_PROCESS_FOR_UB)
.value("RESIDUAL_RMS_PREPOST_NORM", tensorrt_llm::kernels::AllReduceFusionOp::RESIDUAL_RMS_PREPOST_NORM)
.value("RESIDUAL_RMS_NORM_QUANT_FP8", tensorrt_llm::kernels::AllReduceFusionOp::RESIDUAL_RMS_NORM_QUANT_FP8)
.value("RESIDUAL_RMS_NORM_QUANT_NVFP4", tensorrt_llm::kernels::AllReduceFusionOp::RESIDUAL_RMS_NORM_QUANT_NVFP4)
.value("RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4",
tensorrt_llm::kernels::AllReduceFusionOp::RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4)
.value("RESIDUAL_RMS_NORM_OUT_QUANT_FP8",
tensorrt_llm::kernels::AllReduceFusionOp::RESIDUAL_RMS_NORM_OUT_QUANT_FP8);
py::enum_<tensorrt_llm::kernels::AllReduceStrategyType>(m, "AllReduceStrategy")
.value("NCCL", tensorrt_llm::kernels::AllReduceStrategyType::NCCL)
.value("MIN_LATENCY", tensorrt_llm::kernels::AllReduceStrategyType::MIN_LATENCY)
.value("AUTO", tensorrt_llm::kernels::AllReduceStrategyType::AUTO)
.value("UB", tensorrt_llm::kernels::AllReduceStrategyType::UB)
.value("ONESHOT", tensorrt_llm::kernels::AllReduceStrategyType::ONESHOT)
.value("TWOSHOT", tensorrt_llm::kernels::AllReduceStrategyType::TWOSHOT)
.value("NCCL_SYMMETRIC", tensorrt_llm::kernels::AllReduceStrategyType::NCCL_SYMMETRIC);
// Initialize MoeLoadBalancer bindings
initMoeBindings(m);
// Initialize HostFunc bindings
initHostFuncBindings(m);
}
void initBindingsEarly(py::module_& m)
{
py::classh<tr::BufferManager>(m, "BufferManager")
.def(py::init<tr::BufferManager::CudaStreamPtr, bool>(), py::arg("stream"), py::arg("trim_pool") = false,
py::call_guard<py::gil_scoped_release>())
.def_property_readonly("stream", &tr::BufferManager::getStream);
py::class_<tr::SpeculativeDecodingMode>(m, "SpeculativeDecodingMode")
.def(py::init<tr::SpeculativeDecodingMode::UnderlyingType>(), py::arg("state"))
.def_static("NoneType", &tr::SpeculativeDecodingMode::None)
.def_static("DraftTokensExternal", &tr::SpeculativeDecodingMode::DraftTokensExternal)
.def_static("Medusa", &tr::SpeculativeDecodingMode::Medusa)
.def_static("Eagle", &tr::SpeculativeDecodingMode::Eagle)
.def_static("LookaheadDecoding", &tr::SpeculativeDecodingMode::LookaheadDecoding)
.def_static("ExplicitDraftTokens", &tr::SpeculativeDecodingMode::ExplicitDraftTokens)
.def_property_readonly("is_none", &tr::SpeculativeDecodingMode::isNone)
.def_property_readonly("is_draft_tokens_external", &tr::SpeculativeDecodingMode::isDraftTokensExternal)
.def_property_readonly("is_medusa", &tr::SpeculativeDecodingMode::isMedusa)
.def_property_readonly("is_eagle", &tr::SpeculativeDecodingMode::isEagle)
.def_property_readonly("is_lookahead_decoding", &tr::SpeculativeDecodingMode::isLookaheadDecoding)
.def_property_readonly("is_explicit_draft_tokens", &tr::SpeculativeDecodingMode::isExplicitDraftTokens)
.def_property_readonly("updates_position_ids", &tr::SpeculativeDecodingMode::updatesPositionIds)
.def_property_readonly("requires_attention_mask", &tr::SpeculativeDecodingMode::requiresAttentionMask)
.def_property_readonly("predicts_draft_tokens", &tr::SpeculativeDecodingMode::predictsDraftTokens)
.def_property_readonly("needs_kv_cache_rewind", &tr::SpeculativeDecodingMode::needsKVCacheRewind)
.def_property_readonly("variable_draft_length", &tr::SpeculativeDecodingMode::variableDraftLength)
.def_property_readonly("has_draft_logits", &tr::SpeculativeDecodingMode::hasDraftLogits)
.def_property_readonly("needs_decoder_prologue", &tr::SpeculativeDecodingMode::needsDecoderPrologue);
}
} // namespace tensorrt_llm::pybind::runtime
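// --- Illustrative aside (not part of the original file) ---
// A minimal sketch of the py::call_guard<py::gil_scoped_release> idiom applied to most of the
// bindings above: pybind11 drops the GIL before entering the bound C++ function and re-acquires
// it on return, so other Python threads can run while long native work proceeds. The function and
// module names are hypothetical; the callee must not touch Python objects while the GIL is released.
#include <pybind11/pybind11.h>
#include <chrono>
#include <thread>

namespace py = pybind11;

static void busyWaitMs(int ms)
{
    std::this_thread::sleep_for(std::chrono::milliseconds(ms));
}

PYBIND11_MODULE(gil_sketch, m)
{
    m.def("busy_wait_ms", &busyWaitMs, py::arg("ms"), py::call_guard<py::gil_scoped_release>(),
        "Sleep for the given number of milliseconds with the GIL released");
}
// --- End of illustrative aside ---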

View File

@ -1,31 +0,0 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/pybind/common/customCasters.h"
#include <pybind11/pybind11.h>
namespace py = pybind11;
namespace tensorrt_llm::pybind::runtime
{
void initBindings(py::module_& m);
void initBindingsEarly(py::module_& m);
} // namespace tensorrt_llm::pybind::runtime

View File

@ -1,120 +0,0 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "hostfunc.h"
#include "tensorrt_llm/common/logger.h"
#include <cuda_runtime.h>
#include <memory>
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
namespace py = pybind11;
namespace tensorrt_llm::pybind::runtime
{
struct HostFuncUserData
{
bool freeUserData;
py::function pyHostFunc;
py::tuple pyArgs;
py::dict pyKwargs;
HostFuncUserData(bool freeUserData, py::function func, py::tuple args, py::dict kwargs)
: freeUserData(freeUserData)
, pyHostFunc(std::move(func))
, pyArgs(std::move(args))
, pyKwargs(std::move(kwargs))
{
}
};
static void cudaHostFuncTrampoline(void* userData)
{
// Acquire the GIL: this callback runs on a CUDA driver host thread (which does not hold the GIL) and invokes Python code.
py::gil_scoped_acquire gil;
auto hostFuncUserData = std::unique_ptr<HostFuncUserData>(static_cast<HostFuncUserData*>(userData));
try
{
hostFuncUserData->pyHostFunc(*hostFuncUserData->pyArgs, **hostFuncUserData->pyKwargs);
}
catch (py::error_already_set& e)
{
e.restore();
PyErr_Print();
}
if (hostFuncUserData->freeUserData)
{
// freeUserData is true: keep ownership in the unique_ptr so the user data is destroyed when it goes out of scope.
TLLM_LOG_DEBUG("Automatically freeing hostfunc user data %p", hostFuncUserData.get());
}
else
{
// freeUserData is false: release ownership; the caller must free the user data later via freeHostFuncUserData.
hostFuncUserData.release();
}
}
std::optional<uintptr_t> launchHostFunc(
uintptr_t streamPtr, bool freeUserData, py::function pyHostFunc, py::args pyArgs, py::kwargs pyKwargs)
{
auto const stream = reinterpret_cast<cudaStream_t>(streamPtr);
py::gil_scoped_acquire gil;
auto hostFuncUserData
= std::make_unique<HostFuncUserData>(freeUserData, pyHostFunc, py::tuple(pyArgs), py::dict(pyKwargs));
py::gil_scoped_release release;
cudaError_t err = cudaLaunchHostFunc(stream, cudaHostFuncTrampoline, hostFuncUserData.get());
if (err != cudaSuccess)
{
throw std::runtime_error("Failed to launch host function.");
}
// Release the ownership of the user data.
// If freeUserData is true, the user data will be freed by cudaHostFuncTrampoline.
// If freeUserData is false, the user data should be freed by freeHostFuncUserData.
auto userDataPtr = reinterpret_cast<uintptr_t>(hostFuncUserData.release());
return freeUserData ? std::nullopt : std::make_optional(userDataPtr);
}
void freeHostFuncUserData(uintptr_t userDataPtr)
{
// Acquire the GIL to safely release the Python objects.
py::gil_scoped_acquire gil;
// Create a unique_ptr that takes over ownership of the user data;
// the user data is freed when the unique_ptr goes out of scope.
auto hostFuncUserData = std::unique_ptr<HostFuncUserData>(reinterpret_cast<HostFuncUserData*>(userDataPtr));
TLLM_LOG_DEBUG("Manually freeing hostfunc user data %p", hostFuncUserData.get());
}
void initHostFuncBindings(pybind11::module_& m)
{
m.def("launch_hostfunc", &launchHostFunc, "Launch a Python host function to a CUDA stream",
py::call_guard<py::gil_scoped_release>());
m.def("free_hostfunc_user_data", &freeHostFuncUserData, "Free the user data for the Python host function",
py::call_guard<py::gil_scoped_release>());
}
} // namespace tensorrt_llm::pybind::runtime
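For context, a minimal sketch of how the launch_hostfunc / free_hostfunc_user_data pair bound above could be driven from Python. The import path and the use of a PyTorch stream handle are assumptions for illustration, not part of this diff; only the call signatures follow the bindings shown here.
# Hypothetical usage sketch; the submodule path below is an assumption.
import torch
from tensorrt_llm.bindings.internal.runtime import (  # assumed import path
    launch_hostfunc, free_hostfunc_user_data)

def on_stream_reached(tag):
    print(f"host callback fired: {tag}")

stream = torch.cuda.current_stream()

# free_user_data=True: fire-and-forget; the trampoline frees the captured
# Python objects after the callback runs, so no handle is returned.
launch_hostfunc(stream.cuda_stream, True, on_stream_reached, "auto-freed")

# free_user_data=False: a handle is returned and must be freed manually
# once the stream has executed the callback.
handle = launch_hostfunc(stream.cuda_stream, False, on_stream_reached, "manual")
torch.cuda.synchronize()
free_hostfunc_user_data(handle)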

View File

@ -1,27 +0,0 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <pybind11/pybind11.h>
namespace tensorrt_llm::pybind::runtime
{
void initHostFuncBindings(pybind11::module_& m);
} // namespace tensorrt_llm::pybind::runtime

View File

@ -1,138 +0,0 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "moeBindings.h"
#include "tensorrt_llm/runtime/moeLoadBalancer/hostAccessibleDeviceAllocator.h"
#include "tensorrt_llm/runtime/moeLoadBalancer/moeLoadBalancer.h"
#include <pybind11/functional.h>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/stl_bind.h>
#include <vector>
namespace py = pybind11;
namespace tr = tensorrt_llm::runtime;
namespace tk = tensorrt_llm::kernels;
namespace tensorrt_llm::pybind::runtime
{
void pyDoReplication(tk::MoeLoadBalanceMetaInfo const& metaInfo, std::vector<float>& expertLoadFactor,
tr::MoePlacementCpuInfo* cpuPlacement)
{
TLLM_CHECK_WITH_INFO(
metaInfo.expertCount == expertLoadFactor.size(), "expert_count and expert_load_factor size mismatch");
tr::doReplication(metaInfo, expertLoadFactor.data(), cpuPlacement);
};
void pyDoPlacement(tk::MoeLoadBalanceMetaInfo const& metaInfo, std::vector<float>& expertLoadFactor,
tr::MoePlacementCpuInfo* cpuPlacement)
{
TLLM_CHECK_WITH_INFO(
metaInfo.expertCount == expertLoadFactor.size(), "expert_count and expert_load_factor size mismatch");
tr::doPlacement(metaInfo, expertLoadFactor.data(), cpuPlacement);
};
void initMoeBindings(pybind11::module_& m)
{
// Bind MoeWeight struct
py::class_<tr::MoeWeight>(m, "MoeWeight")
.def(py::init<>())
.def_property("weight_ptr", &tr::MoeWeight::getWeightPtr, &tr::MoeWeight::setWeightPtr)
.def_readwrite("height", &tr::MoeWeight::mHeight)
.def_readwrite("width", &tr::MoeWeight::mWidth)
.def_readwrite("pitch", &tr::MoeWeight::mPitch)
.def("__repr__",
[](tr::MoeWeight const& self)
{
return "<MoeWeight ptr=" + std::to_string(self.getWeightPtr())
+ " height=" + std::to_string(self.mHeight) + " width=" + std::to_string(self.mWidth)
+ " pitch=" + std::to_string(self.mPitch) + ">";
});
// Bind MoeLoadBalanceMetaInfo struct
py::class_<tk::MoeLoadBalanceMetaInfo>(m, "MoeLoadBalanceMetaInfo")
.def(py::init<int, int, int, int, int>(), py::arg("expert_count"), py::arg("top_k"), py::arg("ep_rank"),
py::arg("ep_size"), py::arg("slot_count_per_rank"))
.def_readwrite("expert_count", &tk::MoeLoadBalanceMetaInfo::expertCount)
.def_readwrite("top_k", &tk::MoeLoadBalanceMetaInfo::topK)
.def_readwrite("ep_rank", &tk::MoeLoadBalanceMetaInfo::epRank)
.def_readwrite("ep_size", &tk::MoeLoadBalanceMetaInfo::epSize)
.def_readwrite("slot_count_per_rank", &tk::MoeLoadBalanceMetaInfo::slotCountPerRank);
// Bind MoePlacementCpuInfo struct
py::class_<tr::MoePlacementCpuInfo>(m, "MoePlacementCpuInfo")
.def(py::init<>())
.def_readwrite("expert_replica_count", &tr::MoePlacementCpuInfo::expertReplicaCount)
.def_readwrite("rank_expert_ids", &tr::MoePlacementCpuInfo::rankExpertIds);
// Bind SingleLayerMoeLoadBalancer class
py::class_<tr::SingleLayerMoeLoadBalancer, std::shared_ptr<tr::SingleLayerMoeLoadBalancer>>(
m, "SingleLayerMoeLoadBalancer")
.def("add_single_weight_slot", &tr::SingleLayerMoeLoadBalancer::addSingleWeightSlot, py::arg("slot_id"),
py::arg("name"), py::arg("weight_slot"), "Add a single weight slot for a specific slot ID",
py::call_guard<py::gil_scoped_release>())
.def("add_single_host_weight", &tr::SingleLayerMoeLoadBalancer::addSingleHostWeight, py::arg("expert_id"),
py::arg("name"), py::arg("host_weight"), "Add a single host weight for a specific expert ID",
py::call_guard<py::gil_scoped_release>())
.def("set_initial_weight_assignments", &tr::SingleLayerMoeLoadBalancer::setInitialWeightAssignments,
py::arg("initial_weight_assignments"), "Set initial weight assignments for each slot",
py::call_guard<py::gil_scoped_release>())
.def("get_pointer", &tr::SingleLayerMoeLoadBalancer::getSelfPtr,
"Get the pointer of the SingleLayerMoeLoadBalancer", py::call_guard<py::gil_scoped_release>())
.def("get_layer_id", &tr::SingleLayerMoeLoadBalancer::getLayerId,
"Get the layer id of the SingleLayerMoeLoadBalancer", py::call_guard<py::gil_scoped_release>())
.def("get_old_rank_expert_ids", &tr::SingleLayerMoeLoadBalancer::getOldRankExpertIds,
"Get the old rank expert ids of the SingleLayerMoeLoadBalancer", py::call_guard<py::gil_scoped_release>());
// Bind MoeLoadBalancer class
py::class_<tr::MoeLoadBalancer>(m, "MoeLoadBalancer")
.def(py::init<int, int, int>(), py::arg("ep_rank"), py::arg("ep_size"), py::arg("layer_updates_per_iter"),
"Initialize the MoeLoadBalancer with the specified expert parallel rank, size, and update frequency",
py::call_guard<py::gil_scoped_release>())
.def("set_use_gpu_memcpy", &tr::MoeLoadBalancer::setUseGpuMemcpy, py::arg("use_gpu_memcpy"),
"Set whether to use GPU memcpy for weight updates", py::call_guard<py::gil_scoped_release>())
.def("add_layer", &tr::MoeLoadBalancer::AddLayer, py::arg("expert_count"), py::arg("top_k"),
py::arg("slot_count_per_rank"), "Add a new MOE layer to the load balancer",
py::call_guard<py::gil_scoped_release>())
.def("finalize_model", &tr::MoeLoadBalancer::finalizeModel,
"Finalize the model structure, must be called after all layers are added",
py::call_guard<py::gil_scoped_release>())
.def("set_warm_up_iter_count", &tr::MoeLoadBalancer::setWarmUpIterCount, py::arg("iter_count"),
"Set the number of warm-up iterations", py::call_guard<py::gil_scoped_release>())
.def("start_iter", &tr::MoeLoadBalancer::startIter, py::arg("iter_id"), py::arg("enable_statistic"),
py::arg("enable_update_weights"), "Start a new iteration with the given ID and settings",
py::call_guard<py::gil_scoped_release>())
.def("end_iter", &tr::MoeLoadBalancer::endIter, py::arg("iter_id"), "End the iteration with the given ID",
py::call_guard<py::gil_scoped_release>())
.def("shutdown", &tr::MoeLoadBalancer::shutdown, "Shutdown the load balancer and clean up resources",
py::call_guard<py::gil_scoped_release>());
m.def("is_host_accessible_device_memory_supported", &tr::HostAccessibleDeviceAllocator::isSupported,
"If current system support host accessible device memory");
// Bind do_replication function for testing
m.def("do_replication", &pyDoReplication, py::arg("meta_info"), py::arg("expert_load_factor"),
py::arg("cpu_placement"), "Do replication");
// Bind do_placement function for testing
m.def("do_placement", &pyDoPlacement, py::arg("meta_info"), py::arg("expert_load_factor"), py::arg("cpu_placement"),
"Do placement");
}
} // namespace tensorrt_llm::pybind::runtime
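A hedged sketch of the MoeLoadBalancer lifecycle exposed by these bindings, following the argument names bound above. The import path, the single-rank configuration values, and the assumption that add_layer returns the bound SingleLayerMoeLoadBalancer are illustrative only.
# Hypothetical usage sketch; the submodule path is an assumption.
from tensorrt_llm.bindings.internal.runtime import MoeLoadBalancer  # assumed path

lb = MoeLoadBalancer(ep_rank=0, ep_size=1, layer_updates_per_iter=1)

# Register one MoE layer; add_layer is assumed to return the per-layer
# SingleLayerMoeLoadBalancer for attaching weight slots.
layer = lb.add_layer(expert_count=8, top_k=2, slot_count_per_rank=10)

lb.finalize_model()
lb.set_warm_up_iter_count(2)

for iter_id in range(4):
    lb.start_iter(iter_id, enable_statistic=True, enable_update_weights=True)
    # ... run the model forward pass for this iteration ...
    lb.end_iter(iter_id)

lb.shutdown()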

View File

@ -1,27 +0,0 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <pybind11/pybind11.h>
namespace tensorrt_llm::pybind::runtime
{
void initMoeBindings(pybind11::module_& m);
} // namespace tensorrt_llm::pybind::runtime

View File

@ -1,85 +0,0 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "modelSpecBinding.h"
#include "tensorrt_llm/testing/modelSpec.h"
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
namespace py = pybind11;
using tensorrt_llm::testing::ModelSpec;
using tensorrt_llm::testing::KVCacheType;
using tensorrt_llm::testing::QuantMethod;
using tensorrt_llm::testing::OutputContentType;
namespace tensorrt_llm::pybind::testing
{
void initBindings(py::module_& m)
{
py::enum_<QuantMethod>(m, "QuantMethod", py::arithmetic(), "Quantization Method")
.value("NONE", QuantMethod::kNONE, "No Quantization")
.value("SMOOTH_QUANT", QuantMethod::kSMOOTH_QUANT, "Smooth Quantization");
py::enum_<OutputContentType>(m, "OutputContentType", py::arithmetic(), "Output Content Type")
.value("NONE", OutputContentType::kNONE, "No Output Content")
.value("CONTEXT_LOGITS", OutputContentType::kCONTEXT_LOGITS, "Context Logits")
.value("GENERATION_LOGITS", OutputContentType::kGENERATION_LOGITS, "Generation Logits")
.value("LOG_PROBS", OutputContentType::kLOG_PROBS, "Log Probs")
.value("CUM_LOG_PROBS", OutputContentType::kCUM_LOG_PROBS, "Cumulative Log");
py::class_<ModelSpec>(m, "ModelSpec")
.def(py::init<std::string const&, nvinfer1::DataType>())
.def("use_gpt_plugin", &ModelSpec::useGptAttentionPlugin)
.def("use_packed_input", &ModelSpec::usePackedInput)
.def("set_kv_cache_type", &ModelSpec::setKVCacheType)
.def("use_decoder_per_request", &ModelSpec::useDecoderPerRequest)
.def("use_tensor_parallelism", &ModelSpec::useTensorParallelism)
.def("use_pipeline_parallelism", &ModelSpec::usePipelineParallelism)
.def("use_context_parallelism", &ModelSpec::useContextParallelism)
.def("set_draft_tokens", &ModelSpec::setDraftTokens)
.def("use_accept_by_logits", &ModelSpec::useAcceptByLogits)
.def("use_mamba_plugin", &ModelSpec::useMambaPlugin)
.def("gather_logits", &ModelSpec::gatherLogits)
.def("replace_logits", &ModelSpec::replaceLogits)
.def("return_log_probs", &ModelSpec::returnLogProbs)
.def("smoke_test", &ModelSpec::smokeTest)
.def("use_medusa", &ModelSpec::useMedusa)
.def("use_eagle", &ModelSpec::useEagle)
.def("use_lookahead_decoding", &ModelSpec::useLookaheadDecoding)
.def("use_explicit_draft_tokens_decoding", &ModelSpec::useExplicitDraftTokensDecoding)
.def("use_draft_tokens_external_decoding", &ModelSpec::useDraftTokensExternalDecoding)
.def("use_logits", &ModelSpec::useLogits)
.def("use_multiple_profiles", &ModelSpec::useMultipleProfiles)
.def("set_max_input_length", &ModelSpec::setMaxInputLength)
.def("set_max_output_length", &ModelSpec::setMaxOutputLength)
.def("set_quant_method", &ModelSpec::setQuantMethod)
.def("use_lora_plugin", &ModelSpec::useLoraPlugin)
.def("get_input_file", &ModelSpec::getInputFile)
.def("get_model_path", &ModelSpec::getModelPath)
.def("get_results_file", &ModelSpec::getResultsFile)
.def("get_generation_logits_file", &ModelSpec::getGenerationLogitsFile)
.def("get_context_logits_file", &ModelSpec::getContextLogitsFile)
.def("get_cum_log_probs_file", &ModelSpec::getCumLogProbsFile)
.def("get_log_probs_file", &ModelSpec::getLogProbsFile)
.def("enable_context_fmha_fp32_acc", &ModelSpec::enableContextFMHAFp32Acc)
.def("get_enable_context_fmha_fp32_acc", &ModelSpec::getEnableContextFMHAFp32Acc)
.def("__copy__", [](ModelSpec const& self) { return ModelSpec(self); });
}
} // namespace tensorrt_llm::pybind::testing
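For orientation, a small sketch of how the ModelSpec test helper bound above might be configured from Python. Both the import paths and the DataType enum used for the constructor's second argument are assumptions; only the method names come from the bindings in this file.
# Hypothetical test-side usage; import paths and DataType are assumptions.
from tensorrt_llm.bindings import DataType  # assumed enum binding
from tensorrt_llm.bindings.internal.testing import ModelSpec, QuantMethod  # assumed path

spec = ModelSpec("input_tokens.npy", DataType.HALF)
spec.use_gpt_plugin()
spec.use_packed_input()
spec.set_quant_method(QuantMethod.SMOOTH_QUANT)

print(spec.get_model_path(), spec.get_results_file())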

View File

@ -1,30 +0,0 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/pybind/common/customCasters.h"
#include <pybind11/pybind11.h>
namespace py = pybind11;
namespace tensorrt_llm::pybind::testing
{
void initBindings(py::module_& m);
} // namespace tensorrt_llm::pybind::testing

View File

@ -1,83 +0,0 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "bindings.h"
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <tensorrt_llm/kernels/helixAllToAll.h>
#include <tensorrt_llm/thop/attentionOp.h>
#include <tensorrt_llm/thop/moeAlltoAllMeta.h>
#include <torch/extension.h>
namespace py = pybind11;
namespace tensorrt_llm::pybind::thop
{
void initBindings(pybind11::module_& m)
{
// Export MoE A2A constants
for (auto const& kv : torch_ext::moe_comm::getMoeA2AMetaInfoIndexPairs())
{
m.attr(kv.first) = py::int_(kv.second);
}
m.def("attention", &torch_ext::attention,
// Parameters with default values using std::nullopt for optional arguments
py::arg("q"), py::arg("k") = std::nullopt, py::arg("v") = std::nullopt, py::arg("output"),
py::arg("output_sf") = std::nullopt, py::arg("workspace_") = std::nullopt, py::arg("sequence_length"),
py::arg("host_past_key_value_lengths"), py::arg("host_total_kv_lens"), py::arg("context_lengths"),
py::arg("host_context_lengths"), py::arg("host_request_types"),
py::arg("kv_cache_block_offsets") = std::nullopt, py::arg("host_kv_cache_pool_pointers") = std::nullopt,
py::arg("host_kv_cache_pool_mapping") = std::nullopt, py::arg("cache_indirection") = std::nullopt,
py::arg("kv_scale_orig_quant") = std::nullopt, py::arg("kv_scale_quant_orig") = std::nullopt,
py::arg("out_scale") = std::nullopt, py::arg("rotary_inv_freq") = std::nullopt,
py::arg("rotary_cos_sin") = std::nullopt, py::arg("latent_cache") = std::nullopt,
py::arg("q_pe") = std::nullopt, py::arg("block_ids_per_seq") = std::nullopt,
py::arg("attention_sinks") = std::nullopt, py::arg("is_fused_qkv"), py::arg("update_kv_cache"),
py::arg("predicted_tokens_per_seq"), py::arg("layer_idx"), py::arg("num_heads"), py::arg("num_kv_heads"),
py::arg("head_size"), py::arg("tokens_per_block") = std::nullopt, py::arg("max_num_requests"),
py::arg("max_context_length"), py::arg("attention_window_size"), py::arg("sink_token_length"),
py::arg("beam_width"), py::arg("mask_type"), py::arg("quant_mode"), py::arg("q_scaling"),
py::arg("position_embedding_type"), py::arg("rotary_embedding_dim"), py::arg("rotary_embedding_base"),
py::arg("rotary_embedding_scale_type"), py::arg("rotary_embedding_scales"),
py::arg("rotary_embedding_max_position_info"), py::arg("use_paged_context_fmha"),
py::arg("attention_input_type") = std::nullopt, py::arg("is_mla_enable"),
py::arg("chunked_prefill_buffer_batch_size") = std::nullopt, py::arg("q_lora_rank") = std::nullopt,
py::arg("kv_lora_rank") = std::nullopt, py::arg("qk_nope_head_dim") = std::nullopt,
py::arg("qk_rope_head_dim") = std::nullopt, py::arg("v_head_dim") = std::nullopt,
py::arg("mrope_rotary_cos_sin") = std::nullopt, py::arg("mrope_position_deltas") = std::nullopt,
py::arg("mla_tensor_params"), py::arg("attention_chunk_size") = std::nullopt,
py::arg("softmax_stats_tensor") = std::nullopt, py::arg("spec_decoding_bool_params"),
py::arg("spec_decoding_tensor_params"), py::arg("sparse_kv_indices") = std::nullopt,
py::arg("sparse_kv_offsets") = std::nullopt, py::arg("sparse_attn_indices") = std::nullopt,
py::arg("sparse_attn_offsets") = std::nullopt, py::arg("sparse_attn_indices_block_size"),
py::arg("sparse_mla_topk") = std::nullopt,
py::arg("skip_softmax_threshold_scale_factor_prefill") = std::nullopt,
py::arg("skip_softmax_threshold_scale_factor_decode") = std::nullopt,
py::arg("skip_softmax_stat") = std::nullopt, py::arg("cu_q_seqlens") = std::nullopt,
py::arg("cu_kv_seqlens") = std::nullopt, py::arg("fmha_scheduler_counter") = std::nullopt,
py::arg("mla_bmm1_scale") = std::nullopt, py::arg("mla_bmm2_scale") = std::nullopt,
py::arg("quant_q_buffer") = std::nullopt, "Multi-head attention operation",
py::call_guard<py::gil_scoped_release>());
m.def(
"get_helix_workspace_size_per_rank",
[](int cp_size) { return tensorrt_llm::kernels::computeHelixWorkspaceSizePerRank(cp_size); },
py::arg("cp_size"), "Get helix all-to-all workspace size per rank in bytes");
}
} // namespace tensorrt_llm::pybind::thop
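A brief hedged example of the lighter-weight entry points in this file (the full attention binding takes far too many tensors to sketch meaningfully). The submodule path is an assumption, and the attribute-name filter is only a guess at how the exported MoE A2A constants are spelled.
# Hypothetical usage sketch; the submodule path is an assumption.
from tensorrt_llm.bindings.internal import thop  # assumed path

# Workspace bytes each rank needs for helix all-to-all at context-parallel size 4.
print(thop.get_helix_workspace_size_per_rank(4))

# The MoE A2A meta-info indices are exported as plain module attributes;
# the exact names depend on getMoeA2AMetaInfoIndexPairs().
for name in dir(thop):
    if name.isupper():
        print(name, getattr(thop, name))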

View File

@ -1,27 +0,0 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/pybind/common/customCasters.h"
#include <pybind11/pybind11.h>
namespace tensorrt_llm::pybind::thop
{
void initBindings(pybind11::module_& m);
} // namespace tensorrt_llm::pybind::thop

View File

@ -1,55 +0,0 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "bindings.h"
#include "tensorrt_llm/common/config.h"
#include "tensorrt_llm/kernels/userbuffers/ub_interface.h"
#include "tensorrt_llm/kernels/userbuffers/userbuffersManager.h"
namespace py = pybind11;
namespace tub = tensorrt_llm::runtime::ub;
TRTLLM_NAMESPACE_BEGIN
namespace kernels::userbuffers
{
void UserBufferBindings::initBindings(pybind11::module_& m)
{
py::class_<tub::UBBuffer>(m, "UBBuffer")
.def_readonly("size", &tub::UBBuffer::size)
.def_property_readonly("addr", [](tub::UBBuffer& self) { return reinterpret_cast<intptr_t>(self.addr); })
.def_readonly("handle", &tub::UBBuffer::handle)
.def("invalid", &tub::UBBuffer::invalid, py::call_guard<py::gil_scoped_release>());
m.def(
"ub_initialize", [](int tp_size) { tub::ub_initialize(tp_size); }, py::call_guard<py::gil_scoped_release>());
m.def("ub_is_initialized", &tub::ub_is_initialized, py::call_guard<py::gil_scoped_release>());
m.def(
"ub_allocate", [](size_t bytes) { return tub::ub_allocate(bytes); }, py::call_guard<py::gil_scoped_release>());
m.def(
"ub_deallocate", [](intptr_t addr) { return tub::ub_deallocate(reinterpret_cast<void*>(addr)); },
py::call_guard<py::gil_scoped_release>());
m.def("ub_get", &tub::ub_get, py::call_guard<py::gil_scoped_release>());
m.def("ub_supported", &tub::ub_supported, py::call_guard<py::gil_scoped_release>());
m.def("initialize_userbuffers_manager", &tub::initialize_userbuffers_manager,
py::call_guard<py::gil_scoped_release>());
}
} // namespace kernels::userbuffers
TRTLLM_NAMESPACE_END
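A minimal sketch of the UserBuffers helpers bound above, assuming the module path; it allocates and releases a single buffer when the feature is supported.
# Hypothetical usage sketch; the submodule path is an assumption.
from tensorrt_llm.bindings.internal import userbuffers as ub  # assumed path

if ub.ub_supported():
    ub.ub_initialize(2)                 # tp_size
    assert ub.ub_is_initialized()

    buf = ub.ub_allocate(1 << 20)       # UBBuffer with .addr, .size, .handle
    print(hex(buf.addr), buf.size, buf.handle)

    ub.ub_deallocate(buf.addr)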

View File

@ -1,35 +0,0 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/common/config.h"
#include "tensorrt_llm/pybind/common/customCasters.h"
#include <pybind11/pybind11.h>
TRTLLM_NAMESPACE_BEGIN
namespace kernels::userbuffers
{
class UserBufferBindings
{
public:
static void initBindings(pybind11::module_& m);
};
} // namespace kernels::userbuffers
TRTLLM_NAMESPACE_END