/*
 * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <cassert>
#include <cstdint>
#include <filesystem>
#include <memory>
#include <optional>
#include <string>
#include <vector>

#include <pybind11/functional.h>
#include <pybind11/operators.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/stl_bind.h>

#include "tensorrt_llm/batch_manager/kvCacheConfig.h"
#include "tensorrt_llm/batch_manager/trtGptModel.h"
#include "tensorrt_llm/batch_manager/trtGptModelOptionalParams.h"
#include "tensorrt_llm/common/quantization.h"
#include "tensorrt_llm/pybind/batch_manager/algorithms.h"
#include "tensorrt_llm/pybind/batch_manager/bindings.h"
#include "tensorrt_llm/pybind/batch_manager/buffers.h"
#include "tensorrt_llm/pybind/batch_manager/cacheTransceiver.h"
#include "tensorrt_llm/pybind/batch_manager/kvCacheManager.h"
#include "tensorrt_llm/pybind/batch_manager/llmRequest.h"
#include "tensorrt_llm/pybind/executor/bindings.h"
#include "tensorrt_llm/pybind/runtime/bindings.h"
#include "tensorrt_llm/pybind/userbuffers/bindings.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/gptJsonConfig.h"
#include "tensorrt_llm/runtime/ipcNvlsMemory.h"
#include "tensorrt_llm/runtime/ipcUtils.h"
#include "tensorrt_llm/runtime/memoryCounters.h"
#include "tensorrt_llm/runtime/samplingConfig.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"

namespace py = pybind11;
namespace tb = tensorrt_llm::batch_manager;
namespace tbk = tensorrt_llm::batch_manager::kv_cache_manager;
namespace tpb = tensorrt_llm::pybind::batch_manager;
namespace tc = tensorrt_llm::common;
namespace tr = tensorrt_llm::runtime;
namespace tle = tensorrt_llm::executor;

using SizeType32 = tr::SizeType32;
using TokenIdType = tr::TokenIdType;

template <typename T>
using OptVec = std::optional<std::vector<T>>;

#if not defined(TRTLLM_PYBIND_MODULE)
#error "TRTLLM_PYBIND_MODULE must be defined"
#endif

namespace
{
tr::SamplingConfig makeSamplingConfig(std::vector<tr::SamplingConfig> const& configs)
{
    return tr::SamplingConfig(configs);
}
} // namespace

PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
{
    m.doc() = "TensorRT-LLM Python bindings for C++ runtime";

    // Create MpiComm binding first since it's used in the executor bindings
    py::classh<tensorrt_llm::mpi::MpiComm>(m, "MpiComm")
        .def_static("rank",
            []()
            {
                auto& session = tensorrt_llm::mpi::MpiComm::session();
                return session.tensorrt_llm::mpi::MpiComm::getRank();
            })
        .def_static("size",
            []()
            {
                auto& session = tensorrt_llm::mpi::MpiComm::session();
                return session.tensorrt_llm::mpi::MpiComm::getSize();
            })
        .def_static("local_size",
            []()
            {
                auto& session = tensorrt_llm::mpi::MpiComm::localSession();
                return session.tensorrt_llm::mpi::MpiComm::getSize();
            })
        .def_static("local_init", []() { tensorrt_llm::mpi::MpiComm::localSession(); })
        .def_static("set_raw_mpi_session_by_fortran_handle",
            [](int64_t fortran_handle) { tensorrt_llm::mpi::MpiComm::setRawSessionByFortran(fortran_handle); })
        .def_static("split",
            [](size_t color, size_t rank)
            {
                auto& world = tensorrt_llm::mpi::MpiComm::world();
                tensorrt_llm::mpi::MpiComm::setSession(world.split(color, rank));
            });
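
    // Illustrative Python-side view of the MpiComm helpers above (a hedged sketch: the import path
    // `tensorrt_llm.bindings` is the conventional name the build system gives this extension, not
    // something defined in this file):
    //
    //   from tensorrt_llm.bindings import MpiComm
    //   MpiComm.local_init()
    //   print(MpiComm.rank(), MpiComm.size(), MpiComm.local_size())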
    // Create submodule for executor bindings.
    py::module_ executor_submodule = m.def_submodule("executor", "Executor bindings");
    tensorrt_llm::pybind::executor::initBindings(executor_submodule);

    auto buildInfo = m.def_submodule("BuildInfo");
    buildInfo.attr("ENABLE_MULTI_DEVICE") = py::int_(ENABLE_MULTI_DEVICE);

    auto kvCacheConfigGetState = [](tbk::KvCacheConfig const& config)
    {
        return py::make_tuple(config.maxTokens, config.maxAttentionWindowVec, config.sinkTokenLength,
            config.freeGpuMemoryFraction, config.enableBlockReuse, config.useUvm, config.hostCacheSize,
            config.onboardBlocks, config.crossKvCacheFraction, config.secondaryOffloadMinPriority,
            config.eventBufferMaxSize, config.enablePartialReuse, config.copyOnPartialReuse);
    };
    auto kvCacheConfigSetState = [](py::tuple t)
    {
        return tbk::KvCacheConfig(t[0].cast<std::optional<SizeType32>>(),
            t[1].cast<std::optional<std::vector<SizeType32>>>(), t[2].cast<std::optional<SizeType32>>(),
            t[3].cast<std::optional<float>>(), t[4].cast<bool>(), t[5].cast<bool>(),
            t[6].cast<std::optional<size_t>>(), t[7].cast<bool>(), t[8].cast<std::optional<float>>(),
            t[9].cast<std::optional<tle::RetentionPriority>>(), t[10].cast<size_t>(), t[11].cast<bool>(),
            t[12].cast<bool>());
    };

    py::class_<tbk::KvCacheConfig>(m, "KvCacheConfig")
        .def(py::init<std::optional<SizeType32>, std::optional<std::vector<SizeType32>>, std::optional<SizeType32>,
                 std::optional<float>, bool, bool, std::optional<size_t>, bool, std::optional<float>,
                 std::optional<tle::RetentionPriority>, size_t, bool, bool>(),
            py::arg("max_tokens") = py::none(), py::arg("max_attention_window") = py::none(),
            py::arg("sink_token_length") = py::none(), py::arg("free_gpu_memory_fraction") = py::none(),
            py::arg("enable_block_reuse") = false, py::arg("use_uvm") = false,
            py::arg("host_cache_size") = py::none(), py::arg("onboard_blocks") = true,
            py::arg("cross_kv_cache_fraction") = py::none(),
            py::arg("secondary_offload_min_priority") = py::none(), py::arg("event_buffer_max_size") = 0,
            py::arg("enable_partial_reuse") = true, py::arg("copy_on_partial_reuse") = true)
        .def_readwrite("max_tokens", &tbk::KvCacheConfig::maxTokens)
        .def_readwrite("max_attention_window", &tbk::KvCacheConfig::maxAttentionWindowVec)
        .def_readwrite("sink_token_length", &tbk::KvCacheConfig::sinkTokenLength)
        .def_readwrite("free_gpu_memory_fraction", &tbk::KvCacheConfig::freeGpuMemoryFraction)
        .def_readwrite("enable_block_reuse", &tbk::KvCacheConfig::enableBlockReuse)
        .def_readwrite("use_uvm", &tbk::KvCacheConfig::useUvm)
        .def_readwrite("host_cache_size", &tbk::KvCacheConfig::hostCacheSize)
        .def_readwrite("onboard_blocks", &tbk::KvCacheConfig::onboardBlocks)
        .def_readwrite("cross_kv_cache_fraction", &tbk::KvCacheConfig::crossKvCacheFraction)
        .def_readwrite("secondary_offload_min_priority", &tbk::KvCacheConfig::secondaryOffloadMinPriority)
        .def_readwrite("event_buffer_max_size", &tbk::KvCacheConfig::eventBufferMaxSize)
        .def_readwrite("enable_partial_reuse", &tbk::KvCacheConfig::enablePartialReuse)
        .def_readwrite("copy_on_partial_reuse", &tbk::KvCacheConfig::copyOnPartialReuse)
        .def(py::pickle(kvCacheConfigGetState, kvCacheConfigSetState))
        .def("__eq__", &tbk::KvCacheConfig::operator==);
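
    // Illustrative Python-side construction (a hedged sketch; the keyword names mirror the py::arg()
    // declarations above, and the `tensorrt_llm.bindings` import path is an assumption about how the
    // extension is packaged):
    //
    //   from tensorrt_llm.bindings import KvCacheConfig
    //   cfg = KvCacheConfig(free_gpu_memory_fraction=0.9, enable_block_reuse=True)
    //   cfg.max_tokens = 4096   # fields are read-write, matching the def_readwrite entries above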
    py::class_<tb::PeftCacheManagerConfig>(m, "PeftCacheManagerConfig")
        .def(py::init<SizeType32, SizeType32, SizeType32, SizeType32, SizeType32, SizeType32, SizeType32, SizeType32,
                 SizeType32, std::optional<float>, std::optional<size_t>, std::optional<std::string>>(),
            py::arg("num_host_module_layer") = 0, py::arg("num_device_module_layer") = 0,
            py::arg("optimal_adapter_size") = 8, py::arg("max_adapter_size") = 64, py::arg("num_put_workers") = 1,
            py::arg("num_ensure_workers") = 1, py::arg("num_copy_streams") = 1,
            py::arg("max_pages_per_block_host") = 24, py::arg("max_pages_per_block_device") = 8,
            py::arg("device_cache_percent") = std::nullopt, py::arg("host_cache_size") = std::nullopt,
            py::arg("lora_prefetch_dir") = std::nullopt)
        .def_readwrite("num_host_module_layer", &tb::PeftCacheManagerConfig::numHostModuleLayer)
        .def_readwrite("num_device_module_layer", &tb::PeftCacheManagerConfig::numDeviceModuleLayer)
        .def_readwrite("optimal_adapter_size", &tb::PeftCacheManagerConfig::optimalAdapterSize)
        .def_readwrite("max_adapter_size", &tb::PeftCacheManagerConfig::maxAdapterSize)
        .def_readwrite("num_put_workers", &tb::PeftCacheManagerConfig::numPutWorkers)
        .def_readwrite("num_ensure_workers", &tb::PeftCacheManagerConfig::numEnsureWorkers)
        .def_readwrite("num_copy_streams", &tb::PeftCacheManagerConfig::numCopyStreams)
        .def_readwrite("max_pages_per_block_host", &tb::PeftCacheManagerConfig::maxPagesPerBlockHost)
        .def_readwrite("max_pages_per_block_device", &tb::PeftCacheManagerConfig::maxPagesPerBlockDevice)
        .def_readwrite("device_cache_percent", &tb::PeftCacheManagerConfig::deviceCachePercent)
        .def_readwrite("host_cache_size", &tb::PeftCacheManagerConfig::hostCacheSize)
        .def_readwrite("lora_prefetch_dir", &tb::PeftCacheManagerConfig::loraPrefetchDir);

    py::enum_<nvinfer1::DataType>(m, "DataType")
        .value("FLOAT", nvinfer1::DataType::kFLOAT)
        .value("HALF", nvinfer1::DataType::kHALF)
        .value("INT8", nvinfer1::DataType::kINT8)
        .value("INT32", nvinfer1::DataType::kINT32)
        .value("BOOL", nvinfer1::DataType::kBOOL)
        .value("UINT8", nvinfer1::DataType::kUINT8)
        .value("FP8", nvinfer1::DataType::kFP8)
        .value("BF16", nvinfer1::DataType::kBF16)
        .value("INT64", nvinfer1::DataType::kINT64)
        .export_values();

    py::enum_<tr::ModelConfig::ModelVariant>(m, "GptModelVariant")
        .value("GPT", tr::ModelConfig::ModelVariant::kGpt)
        .value("GLM", tr::ModelConfig::ModelVariant::kGlm)
        .value("CHATGLM", tr::ModelConfig::ModelVariant::kChatGlm)
        .value("MAMBA", tr::ModelConfig::ModelVariant::kMamba)
        .value("RECURRENTGEMMA", tr::ModelConfig::ModelVariant::kRecurrentGemma);

    py::enum_<tr::ModelConfig::KVCacheType>(m, "KVCacheType")
        .value("CONTINUOUS", tr::ModelConfig::KVCacheType::kCONTINUOUS)
        .value("PAGED", tr::ModelConfig::KVCacheType::kPAGED)
        .value("DISABLED", tr::ModelConfig::KVCacheType::kDISABLED)
        .def(py::init(&tr::ModelConfig::KVCacheTypeFromString));

    py::enum_<tr::ModelConfig::LayerType>(m, "LayerType")
        .value("ATTENTION", tr::ModelConfig::LayerType::kATTENTION)
        .value("RECURRENT", tr::ModelConfig::LayerType::kRECURRENT);
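
    // The enums above map one-to-one onto their C++ counterparts; illustrative Python-side usage
    // (a hedged sketch, import path assumed):
    //
    //   from tensorrt_llm.bindings import DataType, KVCacheType
    //   dtype = DataType.HALF
    //   kv_cache_type = KVCacheType.PAGED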
    py::enum_<tr::LoraModule::ModuleType>(m, "LoraModuleType")
        .value("INVALID", tr::LoraModule::ModuleType::kINVALID)
        .value("ATTN_QKV", tr::LoraModule::ModuleType::kATTN_QKV)
        .value("ATTN_Q", tr::LoraModule::ModuleType::kATTN_Q)
        .value("ATTN_K", tr::LoraModule::ModuleType::kATTN_K)
        .value("ATTN_V", tr::LoraModule::ModuleType::kATTN_V)
        .value("ATTN_DENSE", tr::LoraModule::ModuleType::kATTN_DENSE)
        .value("MLP_H_TO_4H", tr::LoraModule::ModuleType::kMLP_H_TO_4H)
        .value("MLP_4H_TO_H", tr::LoraModule::ModuleType::kMLP_4H_TO_H)
        .value("MLP_GATE", tr::LoraModule::ModuleType::kMLP_GATE)
        .value("CROSS_ATTN_QKV", tr::LoraModule::ModuleType::kCROSS_ATTN_QKV)
        .value("CROSS_ATTN_Q", tr::LoraModule::ModuleType::kCROSS_ATTN_Q)
        .value("CROSS_ATTN_K", tr::LoraModule::ModuleType::kCROSS_ATTN_K)
        .value("CROSS_ATTN_V", tr::LoraModule::ModuleType::kCROSS_ATTN_V)
        .value("CROSS_ATTN_DENSE", tr::LoraModule::ModuleType::kCROSS_ATTN_DENSE)
        .value("MOE_H_TO_4H", tr::LoraModule::ModuleType::kMOE_H_TO_4H)
        .value("MOE_4H_TO_H", tr::LoraModule::ModuleType::kMOE_4H_TO_H)
        .value("MOE_GATE", tr::LoraModule::ModuleType::kMOE_GATE)
        .value("MOE_ROUTER", tr::LoraModule::ModuleType::kMOE_ROUTER)
        .value("MLP_ROUTER", tr::LoraModule::ModuleType::kMLP_ROUTER)
        .value("MLP_GATE_UP", tr::LoraModule::ModuleType::kMLP_GATE_UP);

    py::class_<tr::LoraModule>(m, "LoraModule")
        .def(py::init<tr::LoraModule::ModuleType, SizeType32, SizeType32, bool, bool, SizeType32, SizeType32>(),
            py::arg("module_type"), py::arg("in_dim"), py::arg("out_dim"), py::arg("in_dim_first"),
            py::arg("out_dim_first"), py::arg("in_tp_split_dim"), py::arg("out_tp_split_dim"))
        .def_property_readonly("module_type", &tr::LoraModule::name)
        .def_property_readonly("in_dim", &tr::LoraModule::inDim)
        .def_property_readonly("out_dim", &tr::LoraModule::outDim)
        .def_property_readonly("in_dim_first", &tr::LoraModule::inDimFirst)
        .def_property_readonly("out_dim_first", &tr::LoraModule::outDimFirst)
        .def_property_readonly("in_tp_split_dim", &tr::LoraModule::inTpSplitDim)
        .def_property_readonly("out_tp_split_dim", &tr::LoraModule::outTpSplitDim);
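
    // Illustrative Python-side usage of LoraModule (a hedged sketch: the import path is assumed and
    // the dimension/split values below are arbitrary example numbers, not defaults from this file):
    //
    //   from tensorrt_llm.bindings import LoraModule, LoraModuleType
    //   mod = LoraModule(LoraModuleType.ATTN_QKV, in_dim=1024, out_dim=3072, in_dim_first=False,
    //                    out_dim_first=True, in_tp_split_dim=-1, out_tp_split_dim=1)
    //   print(mod.in_dim, mod.out_dim)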
    py::class_<tc::QuantMode>(m, "QuantMode")
        .def_static("none", &tc::QuantMode::none)
        .def_static("int4_weights", &tc::QuantMode::int4Weights)
        .def_static("int8_weights", &tc::QuantMode::int8Weights)
        .def_static("activations", &tc::QuantMode::activations)
        .def_static("per_channel_scaling", &tc::QuantMode::perChannelScaling)
        .def_static("per_token_scaling", &tc::QuantMode::perTokenScaling)
        .def_static("per_group_scaling", &tc::QuantMode::perGroupScaling)
        .def_static("int8_kv_cache", &tc::QuantMode::int8KvCache)
        .def_static("fp8_kv_cache", &tc::QuantMode::fp8KvCache)
        .def_static("fp8_qdq", &tc::QuantMode::fp8Qdq)
        .def_property_readonly("value", &tc::QuantMode::value)
        .def("is_set", &tc::QuantMode::isSet, py::arg("mode"))
        .def_property_readonly("has_int4_weights", &tc::QuantMode::hasInt4Weights)
        .def_property_readonly("has_int8_weights", &tc::QuantMode::hasInt8Weights)
        .def_property_readonly("has_activations", &tc::QuantMode::hasActivations)
        .def_property_readonly("has_per_channel_scaling", &tc::QuantMode::hasPerChannelScaling)
        .def_property_readonly("has_per_token_scaling", &tc::QuantMode::hasPerTokenScaling)
        .def_property_readonly("has_per_group_scaling", &tc::QuantMode::hasPerGroupScaling)
        .def_property_readonly("has_static_activation_scaling", &tc::QuantMode::hasStaticActivationScaling)
        .def_property_readonly("has_int8_kv_cache", &tc::QuantMode::hasInt8KvCache)
        .def_property_readonly("has_fp8_kv_cache", &tc::QuantMode::hasFp8KvCache)
        .def_property_readonly("has_fp8_qdq", &tc::QuantMode::hasFp8Qdq)
        .def_property_readonly("has_nvfp4", &tc::QuantMode::hasNvfp4)
        .def_property_readonly("has_kv_cache_quant", &tc::QuantMode::hasKvCacheQuant)
        .def_static("from_description", &tc::QuantMode::fromDescription, py::arg("quantize_weights") = false,
            py::arg("quantize_activations") = false, py::arg("per_token") = false, py::arg("per_channel") = false,
            py::arg("per_group") = false, py::arg("use_int4_weights") = false, py::arg("use_int8_kv_cache") = false,
            py::arg("use_fp8_kv_kache") = false, py::arg("use_fp8_qdq") = false, py::arg("use_fp8_rowwise") = false,
            py::arg("use_w4a8_qserve") = false, py::arg("use_nvfp4") = false, py::arg("use_fp8_block_scales") = false)
        .def_static("use_smooth_quant", &tc::QuantMode::useSmoothQuant, py::arg("per_token") = false,
            py::arg("per_channel") = false)
        .def_static("use_weight_only", &tc::QuantMode::useWeightOnly, py::arg("use_int4_weights") = false,
            py::arg("per_group") = false)
        .def_static("from_quant_algo", &tc::QuantMode::fromQuantAlgo, py::arg("quant_algo") = py::none(),
            py::arg("kv_cache_quant_algo") = py::none())
        .def(py::self + py::self)
        .def(py::self += py::self)
        .def(py::self - py::self)
        .def(py::self -= py::self)
        .def(py::self == py::self)
        .def(py::self != py::self);
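
    // Illustrative Python-side composition of QuantMode flags (a hedged sketch, import path assumed):
    //
    //   from tensorrt_llm.bindings import QuantMode
    //   q = QuantMode.use_weight_only(use_int4_weights=True) + QuantMode.int8_kv_cache()
    //   assert q.has_int4_weights and q.has_int8_kv_cache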
    py::class_<tr::ModelConfig>(m, "ModelConfig")
        .def(py::init<SizeType32, SizeType32, SizeType32, SizeType32, SizeType32, SizeType32, nvinfer1::DataType>(),
            py::arg("vocab_size"), py::arg("num_layers"), py::arg("num_attention_layers"), py::arg("num_rnn_layers"),
            py::arg("num_heads"), py::arg("hidden_size"), py::arg("data_type"))
        .def_property_readonly("vocab_size", &tr::ModelConfig::getVocabSize)
        .def("vocab_size_padded", &tr::ModelConfig::getVocabSizePadded, py::arg("world_size"))
        .def("num_layers", &tr::ModelConfig::getNbLayers, py::arg("pipeline_parallelism") = 1)
        .def("num_attention_layers", &tr::ModelConfig::getNbAttentionLayers, py::arg("pipeline_parallelism") = 1,
            py::arg("pipeline_parallelism_rank") = 0)
        .def("num_rnn_layers", &tr::ModelConfig::getNbRnnLayers, py::arg("pipeline_parallelism") = 1,
            py::arg("pipeline_parallelism_rank") = 0)
        .def("num_kv_heads", &tr::ModelConfig::getNbKvHeads, py::arg("layer_idx"))
        .def("set_num_kv_heads", &tr::ModelConfig::setNbKvHeads, py::arg("num_kv_heads"))
        .def_property_readonly("num_heads", &tr::ModelConfig::getNbHeads)
        .def_property_readonly("hidden_size", &tr::ModelConfig::getHiddenSize)
        .def_property_readonly("size_per_head", &tr::ModelConfig::getSizePerHead)
        .def_property_readonly("data_type", &tr::ModelConfig::getDataType)
        .def_property("head_size", &tr::ModelConfig::getSizePerHead, &tr::ModelConfig::setSizePerHead)
        .def_property(
            "num_kv_heads_per_layer", &tr::ModelConfig::getNumKvHeadsPerLayer, &tr::ModelConfig::setNumKvHeadsPerLayer)
        .def_property("use_gpt_attention_plugin",
            py::overload_cast<>(&tr::ModelConfig::useGptAttentionPlugin, py::const_),
            py::overload_cast<bool>(&tr::ModelConfig::useGptAttentionPlugin))
        .def_property("use_packed_input", py::overload_cast<>(&tr::ModelConfig::usePackedInput, py::const_),
            py::overload_cast<bool>(&tr::ModelConfig::usePackedInput))
        .def_property("kv_cache_type", py::overload_cast<>(&tr::ModelConfig::getKVCacheType, py::const_),
            py::overload_cast<tr::ModelConfig::KVCacheType>(&tr::ModelConfig::setKVCacheType))
        .def_property("tokens_per_block", &tr::ModelConfig::getTokensPerBlock, &tr::ModelConfig::setTokensPerBlock)
        .def_property("quant_mode", &tr::ModelConfig::getQuantMode, &tr::ModelConfig::setQuantMode)
        .def_property_readonly("supports_inflight_batching", &tr::ModelConfig::supportsInflightBatching)
        .def_property("max_batch_size", &tr::ModelConfig::getMaxBatchSize, &tr::ModelConfig::setMaxBatchSize)
        .def_property("max_beam_width", &tr::ModelConfig::getMaxBeamWidth, &tr::ModelConfig::setMaxBeamWidth)
        .def_property("max_input_len", &tr::ModelConfig::getMaxInputLen, &tr::ModelConfig::setMaxInputLen)
        .def_property("max_seq_len", &tr::ModelConfig::getMaxSequenceLen, &tr::ModelConfig::setMaxSequenceLen)
        .def_property("max_num_tokens", &tr::ModelConfig::getMaxNumTokens, &tr::ModelConfig::setMaxNumTokens)
        .def_property("max_prompt_embedding_table_size", &tr::ModelConfig::getMaxPromptEmbeddingTableSize,
            &tr::ModelConfig::setMaxPromptEmbeddingTableSize)
        .def_property_readonly("use_prompt_tuning", &tr::ModelConfig::usePromptTuning)
        .def_property_readonly("use_mrope", &tr::ModelConfig::useMrope)
        .def_property("use_lora_plugin", py::overload_cast<>(&tr::ModelConfig::useLoraPlugin, py::const_),
            py::overload_cast<bool>(&tr::ModelConfig::useLoraPlugin))
        .def_property("layer_types", &tr::ModelConfig::getLayerTypes, &tr::ModelConfig::setLayerTypes)
        .def_property("compute_context_logits",
            py::overload_cast<>(&tr::ModelConfig::computeContextLogits, py::const_),
            py::overload_cast<bool>(&tr::ModelConfig::computeContextLogits))
        .def_property("compute_generation_logits",
            py::overload_cast<>(&tr::ModelConfig::computeGenerationLogits, py::const_),
            py::overload_cast<bool>(&tr::ModelConfig::computeGenerationLogits))
        .def_property("model_variant", &tr::ModelConfig::getModelVariant, &tr::ModelConfig::setModelVariant)
        .def_property(
            "use_cross_attention", &tr::ModelConfig::useCrossAttention, &tr::ModelConfig::setUseCrossAttention)
        .def_property("lora_modules", &tr::ModelConfig::getLoraModules, &tr::ModelConfig::setLoraModules)
        .def_property("max_lora_rank", &tr::ModelConfig::getMaxLoraRank, &tr::ModelConfig::setMaxLoraRank);

    py::class_<tr::WorldConfig>(m, "WorldConfig")
        .def(py::init<SizeType32, SizeType32, SizeType32, SizeType32, SizeType32,
                 std::optional<std::vector<SizeType32>> const&, bool>(),
            py::arg("tensor_parallelism") = 1, py::arg("pipeline_parallelism") = 1,
            py::arg("context_parallelism") = 1, py::arg("rank") = 0,
            py::arg("gpus_per_node") = tr::WorldConfig::kDefaultGpusPerNode, py::arg("device_ids") = py::none(),
            py::arg("enable_attention_dp") = false)
        .def_property_readonly("size", &tr::WorldConfig::getSize)
        .def_property_readonly("tensor_parallelism", &tr::WorldConfig::getTensorParallelism)
        .def_property_readonly("pipeline_parallelism", &tr::WorldConfig::getPipelineParallelism)
        .def_property_readonly("context_parallelism", &tr::WorldConfig::getContextParallelism)
        .def_property_readonly("is_tensor_parallel", &tr::WorldConfig::isTensorParallel)
        .def_property_readonly("is_pipeline_parallel", &tr::WorldConfig::isPipelineParallel)
        .def_property_readonly("is_context_parallel", &tr::WorldConfig::isContextParallel)
        .def_property_readonly("rank", &tr::WorldConfig::getRank)
        .def_property_readonly("local_rank", &tr::WorldConfig::getLocalRank)
        .def_property_readonly("node_rank", &tr::WorldConfig::getNodeRank)
        .def_property_readonly("gpus_per_node", &tr::WorldConfig::getGpusPerNode)
        .def_property_readonly("gpus_per_group", &tr::WorldConfig::getGpusPerGroup)
        .def_property_readonly("device", &tr::WorldConfig::getDevice)
        .def_property_readonly("pipeline_parallel_rank", &tr::WorldConfig::getPipelineParallelRank)
        .def_property_readonly("tensor_parallel_rank", &tr::WorldConfig::getTensorParallelRank)
        .def_property_readonly("context_parallel_rank", &tr::WorldConfig::getContextParallelRank)
        .def_property_readonly("enable_attention_dp", &tr::WorldConfig::enableAttentionDP)
        .def_static("mpi",
            py::overload_cast<SizeType32, std::optional<SizeType32>, std::optional<SizeType32>,
                std::optional<SizeType32>, std::optional<std::vector<SizeType32>> const&, bool>(&tr::WorldConfig::mpi),
            py::arg("gpus_per_node") = tr::WorldConfig::kDefaultGpusPerNode,
            py::arg("tensor_parallelism") = py::none(), py::arg("pipeline_parallelism") = py::none(),
            py::arg("context_parallelism") = py::none(), py::arg("device_ids") = py::none(),
            py::arg("enable_attention_dp") = false);
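
    // Illustrative Python-side usage of WorldConfig (a hedged sketch, import path assumed):
    //
    //   from tensorrt_llm.bindings import WorldConfig
    //   wc = WorldConfig(tensor_parallelism=2, pipeline_parallelism=1, rank=0, gpus_per_node=8)
    //   print(wc.size, wc.tensor_parallel_rank, wc.pipeline_parallel_rank)
    //   # WorldConfig.mpi(...) derives the same layout from the ambient MPI session instead.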
    auto SamplingConfigGetState = [](tr::SamplingConfig const& config) -> py::tuple
    {
        return py::make_tuple(config.beamWidth, config.temperature, config.minLength, config.repetitionPenalty,
            config.presencePenalty, config.frequencyPenalty, config.topK, config.topP, config.randomSeed,
            config.topPDecay, config.topPMin, config.topPResetIds, config.beamSearchDiversityRate,
            config.lengthPenalty, config.earlyStopping, config.noRepeatNgramSize, config.numReturnSequences,
            config.minP, config.beamWidthArray);
    };
    auto SamplingConfigSetState = [](py::tuple t) -> tr::SamplingConfig
    {
        assert(t.size() == 19);
        tr::SamplingConfig config;
        config.beamWidth = t[0].cast<SizeType32>();
        config.temperature = t[1].cast<OptVec<float>>();
        config.minLength = t[2].cast<OptVec<SizeType32>>();
        config.repetitionPenalty = t[3].cast<OptVec<float>>();
        config.presencePenalty = t[4].cast<OptVec<float>>();
        config.frequencyPenalty = t[5].cast<OptVec<float>>();
        config.topK = t[6].cast<OptVec<SizeType32>>();
        config.topP = t[7].cast<OptVec<float>>();
        config.randomSeed = t[8].cast<OptVec<uint64_t>>();
        config.topPDecay = t[9].cast<OptVec<float>>();
        config.topPMin = t[10].cast<OptVec<float>>();
        config.topPResetIds = t[11].cast<OptVec<TokenIdType>>();
        config.beamSearchDiversityRate = t[12].cast<OptVec<float>>();
        config.lengthPenalty = t[13].cast<OptVec<float>>();
        config.earlyStopping = t[14].cast<OptVec<SizeType32>>();
        config.noRepeatNgramSize = t[15].cast<OptVec<SizeType32>>();
        config.numReturnSequences = t[16].cast<SizeType32>();
        config.minP = t[17].cast<OptVec<float>>();
        config.beamWidthArray = t[18].cast<OptVec<std::vector<SizeType32>>>();
        return config;
    };

    py::classh<tr::SamplingConfig>(m, "SamplingConfig")
        .def(py::init<SizeType32>(), py::arg("beam_width") = 1)
        .def(py::init<tle::SamplingConfig, std::optional<tle::ExternalDraftTokensConfig>>(),
            py::arg("executor_sample_config"), py::arg("external_draft_tokens_config") = std::nullopt)
        .def_readwrite("beam_width", &tr::SamplingConfig::beamWidth)
        .def_readwrite("temperature", &tr::SamplingConfig::temperature)
        .def_readwrite("min_length", &tr::SamplingConfig::minLength)
        .def_readwrite("repetition_penalty", &tr::SamplingConfig::repetitionPenalty)
        .def_readwrite("presence_penalty", &tr::SamplingConfig::presencePenalty)
        .def_readwrite("frequency_penalty", &tr::SamplingConfig::frequencyPenalty)
        .def_readwrite("top_k", &tr::SamplingConfig::topK)
        .def_readwrite("top_p", &tr::SamplingConfig::topP)
        .def_readwrite("random_seed", &tr::SamplingConfig::randomSeed)
        .def_readwrite("top_p_decay", &tr::SamplingConfig::topPDecay)
        .def_readwrite("top_p_min", &tr::SamplingConfig::topPMin)
        .def_readwrite("top_p_reset_ids", &tr::SamplingConfig::topPResetIds)
        .def_readwrite("beam_search_diversity_rate", &tr::SamplingConfig::beamSearchDiversityRate)
        .def_readwrite("length_penalty", &tr::SamplingConfig::lengthPenalty)
        .def_readwrite("early_stopping", &tr::SamplingConfig::earlyStopping)
        .def_readwrite("no_repeat_ngram_size", &tr::SamplingConfig::noRepeatNgramSize)
        .def_readwrite("num_return_sequences", &tr::SamplingConfig::numReturnSequences)
        .def_readwrite("min_p", &tr::SamplingConfig::minP)
        .def_readwrite("beam_width_array", &tr::SamplingConfig::beamWidthArray)
        .def(py::pickle(SamplingConfigGetState, SamplingConfigSetState))
        .def("__eq__", &tr::SamplingConfig::operator==);

    py::bind_vector<std::vector<tr::SamplingConfig>>(m, "VectorSamplingConfig");
    m.def("make_sampling_config", &makeSamplingConfig, py::arg("configs"));
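
    // Illustrative Python-side round trip through the pickle support defined above (a hedged sketch,
    // import path assumed; the OptVec fields accept per-request lists):
    //
    //   import pickle
    //   from tensorrt_llm.bindings import SamplingConfig
    //   sc = SamplingConfig(beam_width=1)
    //   sc.temperature = [0.8]
    //   assert pickle.loads(pickle.dumps(sc)) == sc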
    py::class_<tr::GptJsonConfig>(m, "GptJsonConfig")
        .def(py::init<std::string, std::string, std::string, SizeType32, SizeType32, SizeType32, SizeType32,
                 tr::ModelConfig, std::optional<tr::RuntimeDefaults>>(),
            py::arg("name"), py::arg("version"), py::arg("precision"), py::arg("tensor_parallelism"),
            py::arg("pipeline_parallelism"), py::arg("context_parallelism"), py::arg("gpus_per_node"),
            py::arg("model_config"), py::arg("runtime_defaults") = py::none())
        .def_static("parse", py::overload_cast<std::string const&>(&tr::GptJsonConfig::parse), py::arg("json"))
        .def_static("parse_file", py::overload_cast<std::filesystem::path const&>(&tr::GptJsonConfig::parse),
            py::arg("path"))
        .def_property_readonly("model_config", &tr::GptJsonConfig::getModelConfig)
        .def_property_readonly("name", &tr::GptJsonConfig::getName)
        .def_property_readonly("version", &tr::GptJsonConfig::getVersion)
        .def_property_readonly("precision", &tr::GptJsonConfig::getPrecision)
        .def_property_readonly("tensor_parallelism", &tr::GptJsonConfig::getTensorParallelism)
        .def_property_readonly("pipeline_parallelism", &tr::GptJsonConfig::getPipelineParallelism)
        .def_property_readonly("context_parallelism", &tr::GptJsonConfig::getContextParallelism)
        .def_property_readonly("gpus_per_node", &tr::GptJsonConfig::getGpusPerNode)
        .def_property_readonly("world_size", &tr::GptJsonConfig::getWorldSize)
        .def_property_readonly("runtime_defaults", &tr::GptJsonConfig::getRuntimeDefaults)
        .def("engine_filename",
            py::overload_cast<tr::WorldConfig const&, std::string const&>(
                &tr::GptJsonConfig::engineFilename, py::const_),
            py::arg("world_config"), py::arg("model"))
        .def("engine_filename",
            py::overload_cast<tr::WorldConfig const&>(&tr::GptJsonConfig::engineFilename, py::const_),
            py::arg("world_config"));

    py::enum_<tb::LlmRequestState>(m, "LlmRequestState")
        .value("UNKNOWN", tb::LlmRequestState::kUNKNOWN)
        .value("ENCODER_INIT", tb::LlmRequestState::kENCODER_INIT)
        .value("CONTEXT_INIT", tb::LlmRequestState::kCONTEXT_INIT)
        .value("GENERATION_IN_PROGRESS", tb::LlmRequestState::kGENERATION_IN_PROGRESS)
        .value("GENERATION_TO_COMPLETE", tb::LlmRequestState::kGENERATION_TO_COMPLETE)
        .value("GENERATION_COMPLETE", tb::LlmRequestState::kGENERATION_COMPLETE)
        .value("DISAGG_GENERATION_INIT", tb::LlmRequestState::kDISAGG_GENERATION_INIT)
        .value("DISAGG_CONTEXT_TRANS_IN_PROGRESS", tb::LlmRequestState::kDISAGG_CONTEXT_TRANS_IN_PROGRESS)
        .value("DISAGG_CONTEXT_COMPLETE", tb::LlmRequestState::kDISAGG_CONTEXT_COMPLETE)
        .value("DISAGG_GENERATION_TRANS_IN_PROGRESS", tb::LlmRequestState::kDISAGG_GENERATION_TRANS_IN_PROGRESS)
        .value("DISAGG_GENERATION_TRANS_COMPLETE", tb::LlmRequestState::kDISAGG_GENERATION_TRANS_COMPLETE)
        .value("DISAGG_CONTEXT_INIT_AND_TRANS", tb::LlmRequestState::kDISAGG_CONTEXT_INIT_AND_TRANS);

    py::enum_<tb::TrtGptModelType>(m, "TrtGptModelType")
        .value("V1", tb::TrtGptModelType::V1)
        .value("InflightBatching", tb::TrtGptModelType::InflightBatching)
        .value("InflightFusedBatching", tb::TrtGptModelType::InflightFusedBatching);

    auto gptModelParamsGetState = [&kvCacheConfigGetState](tb::TrtGptModelOptionalParams const& params)
    {
        auto kvCacheState = kvCacheConfigGetState(params.kvCacheConfig);
        return py::make_tuple(kvCacheState, params.enableTrtOverlap, params.deviceIds, params.normalizeLogProbs,
            params.enableChunkedContext, params.decodingConfig.getDecodingMode());
    };
    auto gptModelParamsSetState = [&kvCacheConfigSetState](py::tuple t)
    {
        auto kvCacheConfig = kvCacheConfigSetState(t[0]);
        return tb::TrtGptModelOptionalParams(kvCacheConfig, t[1].cast<bool>(),
            t[2].cast<std::optional<std::vector<SizeType32>>>(), t[3].cast<bool>(), t[4].cast<bool>(),
            tb::PeftCacheManagerConfig{},
            tensorrt_llm::executor::DecodingConfig(t[5].cast<std::optional<tle::DecodingMode>>()));
    };
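
    // The two lambdas above provide pickle support for TrtGptModelOptionalParams: the kv_cache_config
    // member is delegated to the KvCacheConfig get/set-state lambdas defined earlier, and only the
    // TRT overlap, device ids, log-prob normalization, chunked-context and decoding-mode fields are
    // round-tripped alongside it.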
    py::class_<tb::TrtGptModelOptionalParams>(m, "TrtGptModelOptionalParams")
        .def(py::init<tbk::KvCacheConfig const&, bool, std::optional<std::vector<SizeType32>> const&, bool, bool,
                 tb::PeftCacheManagerConfig const&>(),
            py::arg_v("kv_cache_config", tbk::KvCacheConfig{}, "KvCacheConfig()"),
            py::arg("enable_trt_overlap") = false, py::arg("device_ids") = std::nullopt,
            py::arg("normalize_log_probs") = true, py::arg("enable_chunked_context") = false,
            py::arg_v("peft_cache_manager_config", tb::PeftCacheManagerConfig{}, "PeftCacheManagerConfig()"))
        .def(py::init<tle::ExecutorConfig const&, bool>(), py::arg("executor_config"),
            py::arg("is_leader_in_orch_mode") = false)
        .def_readwrite("kv_cache_config", &tb::TrtGptModelOptionalParams::kvCacheConfig)
        .def_readwrite("enable_trt_overlap", &tb::TrtGptModelOptionalParams::enableTrtOverlap)
        .def_readwrite("device_ids", &tb::TrtGptModelOptionalParams::deviceIds)
        .def_readwrite("enable_chunked_context", &tb::TrtGptModelOptionalParams::enableChunkedContext)
        .def_readwrite("normalize_log_probs", &tb::TrtGptModelOptionalParams::normalizeLogProbs)
        .def_readwrite("decoding_config", &tb::TrtGptModelOptionalParams::decodingConfig)
        .def_readwrite("gpu_weights_percent", &tb::TrtGptModelOptionalParams::gpuWeightsPercent)
        .def_readwrite("max_beam_width", &tb::TrtGptModelOptionalParams::maxBeamWidth)
        .def_readwrite("scheduler_config", &tb::TrtGptModelOptionalParams::schedulerConfig)
        .def(py::pickle(gptModelParamsGetState, gptModelParamsSetState))
        .def("__eq__", &tb::TrtGptModelOptionalParams::operator==);

    py::class_<tr::MemoryCounters>(m, "MemoryCounters")
        .def_static("instance", &tr::MemoryCounters::getInstance, py::return_value_policy::reference)
        .def_property_readonly("gpu", &tr::MemoryCounters::getGpu)
        .def_property_readonly("cpu", &tr::MemoryCounters::getCpu)
        .def_property_readonly("pinned", &tr::MemoryCounters::getPinned)
        .def_property_readonly("uvm", &tr::MemoryCounters::getUVM);

    auto mInternal = m.def_submodule("internal", "Internal submodule of TRTLLM runtime");

    auto mInternalRuntime = mInternal.def_submodule("runtime", "Runtime internal bindings");
    tensorrt_llm::pybind::runtime::initBindings(mInternalRuntime);

    auto mInternalBatchManager = mInternal.def_submodule("batch_manager", "Batch manager internal bindings");
    tpb::initBindings(mInternalBatchManager);
    tb::kv_cache_manager::KVCacheManagerBindings::initBindings(mInternalBatchManager);
    tb::BasePeftCacheManagerBindings::initBindings(mInternalBatchManager);
    tb::CacheTransceiverBindings::initBindings(mInternalBatchManager);
    tpb::Buffers::initBindings(mInternalBatchManager);

    auto mInternalAlgorithms = mInternal.def_submodule("algorithms", "Algorithms internal bindings");
    tpb::algorithms::initBindings(mInternalAlgorithms);

    auto mUserbuffers = mInternal.def_submodule("userbuffers", "User buffers internal bindings");
    tensorrt_llm::kernels::userbuffers::UserBufferBindings::initBindings(mUserbuffers);

    // NVLS allocators
    py::class_<tr::IpcNvlsHandle>(m, "IpcNvlsHandle")
        .def(py::init<>())
        .def_readwrite("uc_ptr", &tr::IpcNvlsHandle::uc_ptr)
        .def_readwrite("mc_ptr", &tr::IpcNvlsHandle::mc_ptr)
        .def_readwrite("size", &tr::IpcNvlsHandle::size)
        .def("get_ipc_ptrs",
            [](tr::IpcNvlsHandle& self) { return reinterpret_cast<uintptr_t>(self.ipc_uc_ptrs.data()); });

    m.def("ipc_nvls_allocate", &tr::ipcNvlsAllocate, py::return_value_policy::reference);
    m.def("ipc_nvls_free", &tr::ipcNvlsFree);
    m.def("ipc_nvls_supported", &tr::ipcNvlsSupported);
}