/* * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "request.h" #include "tensorrt_llm/common/assert.h" #include "tensorrt_llm/common/logger.h" #include "tensorrt_llm/executor/executor.h" #include "tensorrt_llm/executor/tensor.h" #include "tensorrt_llm/executor/types.h" #include "tensorrt_llm/runtime/cudaStream.h" #include #include #include #include #include #include #include #include namespace py = pybind11; namespace tle = tensorrt_llm::executor; using Tensor = tle::Tensor; using SizeType32 = tle::SizeType32; using FloatType = tle::FloatType; using VecTokens = tle::VecTokens; using IdType = tle::IdType; using VecTokenExtraIds = tle::VecTokenExtraIds; namespace tensorrt_llm::pybind::executor { void initRequestBindings(pybind11::module_& m) { py::enum_(m, "RequestType") .value("REQUEST_TYPE_CONTEXT_AND_GENERATION", tle::RequestType::REQUEST_TYPE_CONTEXT_AND_GENERATION) .value("REQUEST_TYPE_CONTEXT_ONLY", tle::RequestType::REQUEST_TYPE_CONTEXT_ONLY) .value("REQUEST_TYPE_GENERATION_ONLY", tle::RequestType::REQUEST_TYPE_GENERATION_ONLY); py::enum_(m, "FinishReason") .value("NOT_FINISHED", tle::FinishReason::kNOT_FINISHED) .value("END_ID", tle::FinishReason::kEND_ID) .value("STOP_WORDS", tle::FinishReason::kSTOP_WORDS) .value("LENGTH", tle::FinishReason::kLENGTH) .value("TIMED_OUT", tle::FinishReason::kTIMED_OUT) .value("CANCELLED", tle::FinishReason::kCANCELLED); auto samplingConfigGetstate = [](tle::SamplingConfig const& self) { return py::make_tuple(self.getBeamWidth(), self.getTopK(), self.getTopP(), self.getTopPMin(), self.getTopPResetIds(), self.getTopPDecay(), self.getSeed(), self.getTemperature(), self.getMinTokens(), self.getBeamSearchDiversityRate(), self.getRepetitionPenalty(), self.getPresencePenalty(), self.getFrequencyPenalty(), self.getLengthPenalty(), self.getEarlyStopping(), self.getNoRepeatNgramSize(), self.getNumReturnSequences(), self.getMinP(), self.getBeamWidthArray()); }; auto samplingConfigSetstate = [](py::tuple const& state) { if (state.size() != 19) { throw std::runtime_error("Invalid SamplingConfig state!"); } return tle::SamplingConfig(state[0].cast(), // BeamWidth state[1].cast>(), // TopK state[2].cast>(), // TopP state[3].cast>(), // TopPMin state[4].cast>(), // TopPResetIds state[5].cast>(), // TopPDecay state[6].cast>(), // Seed state[7].cast>(), // Temperature state[8].cast>(), // MinTokens state[9].cast>(), // BeamSearchDiversityRate state[10].cast>(), // RepetitionPenalty state[11].cast>(), // PresencePenalty state[12].cast>(), // FrequencyPenalty state[13].cast>(), // LengthPenalty state[14].cast>(), // EarlyStopping state[15].cast>(), // NoRepeatNgramSize state[16].cast>(), // NumReturnSequences state[17].cast>(), // MinP state[18].cast>>() // BeamWidthArray ); }; py::class_(m, "SamplingConfig") // A modified version of constructor to accpect deprecated args randomSeed and minLength // TODO(enweiz): use the original constructor after the deprecated args are removed .def( py::init( [](tle::SizeType32 beamWidth, std::optional const& topK, std::optional const& topP, std::optional const& topPMin, std::optional const& topPResetIds, std::optional const& topPDecay, std::optional seed, std::optional const& randomSeed, std::optional const& temperature, std::optional minTokens, std::optional const& minLength, std::optional const& beamSearchDiversityRate, std::optional const& repetitionPenalty, std::optional const& presencePenalty, std::optional const& frequencyPenalty, std::optional const& lengthPenalty, std::optional const& earlyStopping, std::optional const& noRepeatNgramSize, std::optional const& numReturnSequences, std::optional const& minP, std::optional> const& beamWidthArray) { if (randomSeed.has_value()) { TLLM_LOG_WARNING("random_seed is being deprecated; please use seed instead."); if (!seed.has_value()) { seed = randomSeed; } } if (minLength.has_value()) { TLLM_LOG_WARNING("min_length is being deprecated; please use min_tokens instead."); if (!minTokens.has_value()) { minTokens = minLength; } } return std::make_unique(beamWidth, topK, topP, topPMin, topPResetIds, topPDecay, seed, temperature, minTokens, beamSearchDiversityRate, repetitionPenalty, presencePenalty, frequencyPenalty, lengthPenalty, earlyStopping, noRepeatNgramSize, numReturnSequences, minP, beamWidthArray); }), py::arg("beam_width") = 1, py::kw_only(), py::arg("top_k") = py::none(), py::arg("top_p") = py::none(), py::arg("top_p_min") = py::none(), py::arg("top_p_reset_ids") = py::none(), py::arg("top_p_decay") = py::none(), py::arg("seed") = py::none(), py::arg("random_seed") = py::none(), py::arg("temperature") = py::none(), py::arg("min_tokens") = py::none(), py::arg("min_length") = py::none(), py::arg("beam_search_diversity_rate") = py::none(), py::arg("repetition_penalty") = py::none(), py::arg("presence_penalty") = py::none(), py::arg("frequency_penalty") = py::none(), py::arg("length_penalty") = py::none(), py::arg("early_stopping") = py::none(), py::arg("no_repeat_ngram_size") = py::none(), py::arg("num_return_sequences") = py::none(), py::arg("min_p") = py::none(), py::arg("beam_width_array") = py::none()) .def_property("beam_width", &tle::SamplingConfig::getBeamWidth, &tle::SamplingConfig::setBeamWidth) .def_property("top_k", &tle::SamplingConfig::getTopK, &tle::SamplingConfig::setTopK) .def_property("top_p", &tle::SamplingConfig::getTopP, &tle::SamplingConfig::setTopP) .def_property("top_p_min", &tle::SamplingConfig::getTopPMin, &tle::SamplingConfig::setTopPMin) .def_property("top_p_reset_ids", &tle::SamplingConfig::getTopPResetIds, &tle::SamplingConfig::setTopPResetIds) .def_property("top_p_decay", &tle::SamplingConfig::getTopPDecay, &tle::SamplingConfig::setTopPDecay) .def_property("seed", &tle::SamplingConfig::getSeed, &tle::SamplingConfig::setSeed) .def_property("random_seed", &tle::SamplingConfig::getRandomSeed, &tle::SamplingConfig::setRandomSeed) .def_property("temperature", &tle::SamplingConfig::getTemperature, &tle::SamplingConfig::setTemperature) .def_property("min_tokens", &tle::SamplingConfig::getMinTokens, &tle::SamplingConfig::setMinTokens) .def_property("min_length", &tle::SamplingConfig::getMinLength, &tle::SamplingConfig::setMinLength) .def_property("beam_search_diversity_rate", &tle::SamplingConfig::getBeamSearchDiversityRate, &tle::SamplingConfig::setBeamSearchDiversityRate) .def_property("repetition_penalty", &tle::SamplingConfig::getRepetitionPenalty, &tle::SamplingConfig::setRepetitionPenalty) .def_property("presence_penalty", &tle::SamplingConfig::getPresencePenalty, [](tle::SamplingConfig& self, std::optional v) { self.setPresencePenalty(v); }) .def_property( "frequency_penalty", &tle::SamplingConfig::getFrequencyPenalty, &tle::SamplingConfig::setFrequencyPenalty) .def_property("length_penalty", &tle::SamplingConfig::getLengthPenalty, &tle::SamplingConfig::setLengthPenalty) .def_property("early_stopping", &tle::SamplingConfig::getEarlyStopping, &tle::SamplingConfig::setEarlyStopping) .def_property("no_repeat_ngram_size", &tle::SamplingConfig::getNoRepeatNgramSize, &tle::SamplingConfig::setNoRepeatNgramSize) .def_property("num_return_sequences", &tle::SamplingConfig::getNumReturnSequences, &tle::SamplingConfig::setNumReturnSequences) .def_property("min_p", &tle::SamplingConfig::getMinP, &tle::SamplingConfig::setMinP) .def_property( "beam_width_array", &tle::SamplingConfig::getBeamWidthArray, &tle::SamplingConfig::setBeamWidthArray) .def(py::pickle(samplingConfigGetstate, samplingConfigSetstate)); auto additionalModelOutputGetstate = [](tle::AdditionalModelOutput const& self) { return py::make_tuple(self.name, self.gatherContext); }; auto additionalModelOutputSetstate = [](py::tuple const& state) { if (state.size() != 2) { throw std::runtime_error("Invalid AdditionalModelOutput state!"); } return tle::AdditionalModelOutput(state[0].cast(), state[1].cast()); }; py::class_(m, "AdditionalModelOutput") .def(py::init(), py::arg("name"), py::arg("gather_context") = false) .def_readwrite("name", &tle::AdditionalModelOutput::name) .def_readwrite("gather_context", &tle::AdditionalModelOutput::gatherContext) .def(py::pickle(additionalModelOutputGetstate, additionalModelOutputSetstate)); auto outputConfigGetstate = [](tle::OutputConfig const& self) { return py::make_tuple(self.returnLogProbs, self.returnContextLogits, self.returnGenerationLogits, self.excludeInputFromOutput, self.returnEncoderOutput, self.returnPerfMetrics, self.additionalModelOutputs); }; auto outputConfigSetstate = [](py::tuple const& state) { if (state.size() != 7) { throw std::runtime_error("Invalid OutputConfig state!"); } return tle::OutputConfig(state[0].cast(), state[1].cast(), state[2].cast(), state[3].cast(), state[4].cast(), state[5].cast(), state[6].cast>>()); }; py::class_(m, "OutputConfig") .def(py::init>>(), py::arg("return_log_probs") = false, py::arg("return_context_logits") = false, py::arg("return_generation_logits") = false, py::arg("exclude_input_from_output") = false, py::arg("return_encoder_output") = false, py::arg("return_perf_metrics") = false, py::arg("additional_model_outputs") = py::none()) .def_readwrite("return_log_probs", &tle::OutputConfig::returnLogProbs) .def_readwrite("return_context_logits", &tle::OutputConfig::returnContextLogits) .def_readwrite("return_generation_logits", &tle::OutputConfig::returnGenerationLogits) .def_readwrite("exclude_input_from_output", &tle::OutputConfig::excludeInputFromOutput) .def_readwrite("return_encoder_output", &tle::OutputConfig::returnEncoderOutput) .def_readwrite("return_perf_metrics", &tle::OutputConfig::returnPerfMetrics) .def_readwrite("additional_model_outputs", &tle::OutputConfig::additionalModelOutputs) .def(py::pickle(outputConfigGetstate, outputConfigSetstate)); auto externalDraftTokensConfigGetstate = [](tle::ExternalDraftTokensConfig const& self) { return py::make_tuple(self.getTokens(), self.getLogits(), self.getAcceptanceThreshold()); }; auto externalDraftTokensConfigSetstate = [](py::tuple const& state) { if (state.size() != 3) { throw std::runtime_error("Invalid ExternalDraftTokensConfig state!"); } return tle::ExternalDraftTokensConfig(state[0].cast(), state[1].cast>(), state[2].cast>()); }; py::class_(m, "ExternalDraftTokensConfig") .def(py::init, std::optional const&, std::optional>(), py::arg("tokens"), py::arg("logits") = py::none(), py::arg("acceptance_threshold") = py::none(), py::arg("fast_logits") = py::none()) .def_property_readonly("tokens", &tle::ExternalDraftTokensConfig::getTokens) .def_property_readonly("logits", &tle::ExternalDraftTokensConfig::getLogits) .def_property_readonly("acceptance_threshold", &tle::ExternalDraftTokensConfig::getAcceptanceThreshold) .def(py::pickle(externalDraftTokensConfigGetstate, externalDraftTokensConfigSetstate)) .def_property_readonly("fast_logits", &tle::ExternalDraftTokensConfig::getFastLogits); auto promptTuningConfigGetstate = [](tle::PromptTuningConfig const& self) { return py::make_tuple(self.getEmbeddingTable(), self.getInputTokenExtraIds()); }; auto promptTuningConfigSetstate = [](py::tuple const& state) { if (state.size() != 2) { throw std::runtime_error("Invalid PromptTuningConfig state!"); } return tle::PromptTuningConfig(state[0].cast(), state[1].cast>()); }; py::class_(m, "PromptTuningConfig") .def(py::init>(), py::arg("embedding_table"), py::arg("input_token_extra_ids") = py::none()) .def_property_readonly("embedding_table", &tle::PromptTuningConfig::getEmbeddingTable) .def_property_readonly("input_token_extra_ids", &tle::PromptTuningConfig::getInputTokenExtraIds) .def(py::pickle(promptTuningConfigGetstate, promptTuningConfigSetstate)); auto loraConfigGetstate = [](tle::LoraConfig const& self) { return py::make_tuple(self.getTaskId(), self.getWeights(), self.getConfig()); }; auto loraConfigSetstate = [](py::tuple const& state) { if (state.size() != 3) { throw std::runtime_error("Invalid LoraConfig state!"); } return tle::LoraConfig( state[0].cast(), state[1].cast>(), state[2].cast>()); }; py::class_(m, "LoraConfig") .def(py::init, std::optional>(), py::arg("task_id"), py::arg("weights") = py::none(), py::arg("config") = py::none()) .def_property_readonly("task_id", &tle::LoraConfig::getTaskId) .def_property_readonly("weights", &tle::LoraConfig::getWeights) .def_property_readonly("config", &tle::LoraConfig::getConfig) .def(py::pickle(loraConfigGetstate, loraConfigSetstate)); auto MropeConfigGetstate = [](tle::MropeConfig const& self) { return py::make_tuple(self.getMRopeRotaryCosSin(), self.getMRopePositionDeltas()); }; auto MropeConfigSetstate = [](py::tuple const& state) { if (state.size() != 2) { throw std::runtime_error("Invalid MropeConfig state!"); } return tle::MropeConfig(state[0].cast(), state[1].cast()); }; py::class_(m, "MropeConfig") .def(py::init(), py::arg("mrope_rotary_cos_sin"), py::arg("mrope_position_deltas")) .def_property_readonly("mrope_rotary_cos_sin", &tle::MropeConfig::getMRopeRotaryCosSin) .def_property_readonly("mrope_position_deltas", &tle::MropeConfig::getMRopePositionDeltas) .def(py::pickle(MropeConfigGetstate, MropeConfigSetstate)); auto lookaheadDecodingConfigGetstate = [](tle::LookaheadDecodingConfig const& self) { return py::make_tuple(self.getWindowSize(), self.getNgramSize(), self.getVerificationSetSize()); }; auto lookaheadDecodingConfigSetstate = [](py::tuple const& state) { if (state.size() != 3) { throw std::runtime_error("Invalid LookaheadDecodingConfig state!"); } return tle::LookaheadDecodingConfig( state[0].cast(), state[1].cast(), state[2].cast()); }; py::class_(m, "LookaheadDecodingConfig") .def(py::init(), py::arg("max_window_size"), py::arg("max_ngram_size"), py::arg("max_verification_set_size")) .def_property_readonly("max_window_size", &tle::LookaheadDecodingConfig::getWindowSize) .def_property_readonly("max_ngram_size", &tle::LookaheadDecodingConfig::getNgramSize) .def_property_readonly("max_verification_set_size", &tle::LookaheadDecodingConfig::getVerificationSetSize) .def("calculate_speculative_resource", &tle::LookaheadDecodingConfig::calculateSpeculativeResource) .def_static( "calculate_speculative_resource_tuple", &tle::LookaheadDecodingConfig::calculateSpeculativeResourceTuple) .def(py::pickle(lookaheadDecodingConfigGetstate, lookaheadDecodingConfigSetstate)) .def_static("get_default_lookahead_decoding_window", []() { return tle::LookaheadDecodingConfig::kDefaultLookaheadDecodingWindow; }) .def_static("get_default_lookahead_decoding_ngram", []() { return tle::LookaheadDecodingConfig::kDefaultLookaheadDecodingNgram; }) .def_static("get_default_lookahead_decoding_verification_set", []() { return tle::LookaheadDecodingConfig::kDefaultLookaheadDecodingVerificationSet; }); auto TokenRangeRetentionConfigGetstate = [](tle::KvCacheRetentionConfig::TokenRangeRetentionConfig const& self) { return py::make_tuple(self.tokenStart, self.tokenEnd, self.priority, self.durationMs); }; auto TokenRangeRetentionConfigSetstate = [](py::tuple const& state) { if (state.size() != 4) { throw std::runtime_error("Invalid state!"); } return tle::KvCacheRetentionConfig::TokenRangeRetentionConfig(state[0].cast(), state[1].cast>(), state[2].cast(), state[3].cast>()); }; auto kvCacheRetentionConfigGetstate = [](tle::KvCacheRetentionConfig const& self) { return py::make_tuple( self.getTokenRangeRetentionConfigs(), self.getDecodeRetentionPriority(), self.getDecodeDurationMs()); }; auto kvCacheRetentionConfigSetstate = [](py::tuple const& state) { if (state.size() != 3) { throw std::runtime_error("Invalid state!"); } return tle::KvCacheRetentionConfig( state[0].cast>(), state[1].cast(), state[2].cast>()); }; auto kvCacheRetentionConfig = py::class_(m, "KvCacheRetentionConfig"); py::class_( kvCacheRetentionConfig, "TokenRangeRetentionConfig") .def(py::init, tle::RetentionPriority, std::optional>(), py::arg("token_start"), py::arg("token_end"), py::arg("priority"), py::arg("duration_ms") = py::none()) .def_readwrite("token_start", &tle::KvCacheRetentionConfig::TokenRangeRetentionConfig::tokenStart) .def_readwrite("token_end", &tle::KvCacheRetentionConfig::TokenRangeRetentionConfig::tokenEnd) .def_readwrite("priority", &tle::KvCacheRetentionConfig::TokenRangeRetentionConfig::priority) .def_readwrite("duration_ms", &tle::KvCacheRetentionConfig::TokenRangeRetentionConfig::durationMs) .def(py::pickle(TokenRangeRetentionConfigGetstate, TokenRangeRetentionConfigSetstate)) .def("__eq__", &tle::KvCacheRetentionConfig::TokenRangeRetentionConfig::operator==); // There's a circular dependency between the declaration of the TokenRangeRetentionPriority and // KvCacheRetentionConfig bindings. Defer definition of the KvCacheRetentionConfig bindings until the // TokenRangeRetentionPriority bindings have been defined. kvCacheRetentionConfig .def(py::init, tle::RetentionPriority, std::optional>(), py::arg("token_range_retention_configs"), py::arg("decode_retention_priority") = tle::KvCacheRetentionConfig::kDefaultRetentionPriority, py::arg("decode_duration_ms") = py::none()) .def_property_readonly( "token_range_retention_configs", &tle::KvCacheRetentionConfig::getTokenRangeRetentionConfigs) .def_property_readonly("decode_retention_priority", &tle::KvCacheRetentionConfig::getDecodeRetentionPriority) .def_property_readonly("decode_duration_ms", &tle::KvCacheRetentionConfig::getDecodeDurationMs) .def(py::pickle(kvCacheRetentionConfigGetstate, kvCacheRetentionConfigSetstate)) .def("__eq__", &tle::KvCacheRetentionConfig::operator==); auto ContextPhaseParamsGetState = [](tle::ContextPhaseParams const& self) { if (self.getState() != nullptr) { auto serializedState = self.getSerializedState(); return py::make_tuple(self.getFirstGenTokens(), self.getReqId(), py::bytes(serializedState.data(), serializedState.size()), self.getDraftTokens()); } return py::make_tuple(self.getFirstGenTokens(), self.getReqId(), py::none(), self.getDraftTokens()); }; auto ContextPhaseParamsSetState = [](py::tuple const& state) { if (state.size() != 4) { throw std::runtime_error("Invalid ContextPhaseParams state!"); } if (!state[2].is_none()) { auto opaque_state = state[2].cast(); auto opaque_state_str_view = std::string_view(opaque_state.cast()); return std::make_unique(state[0].cast(), state[1].cast(), std::vector(opaque_state_str_view.begin(), opaque_state_str_view.end()), state[3].cast>()); } return std::make_unique(state[0].cast(), state[1].cast(), state[3].cast>()); }; py::class_(m, "ContextPhaseParams") .def(py::init( [](VecTokens const& first_gen_tokens, tle::ContextPhaseParams::RequestIdType req_id, std::optional const& opaque_state, std::optional const& draft_tokens) { if (opaque_state) { auto opaque_state_str_view = std::string_view(opaque_state.value().cast()); return std::make_unique(first_gen_tokens, req_id, std::vector(opaque_state_str_view.begin(), opaque_state_str_view.end()), draft_tokens); } return std::make_unique(first_gen_tokens, req_id, draft_tokens); })) .def_property_readonly("first_gen_tokens", &tle::ContextPhaseParams::getFirstGenTokens) .def_property_readonly("draft_tokens", &tle::ContextPhaseParams::getDraftTokens) .def_property_readonly("req_id", &tle::ContextPhaseParams::getReqId) .def_property_readonly("opaque_state", [](tle::ContextPhaseParams const& self) { std::optional opaque_state{std::nullopt}; if (self.getState() != nullptr) { auto serializedState = self.getSerializedState(); opaque_state = py::bytes(serializedState.data(), serializedState.size()); } return opaque_state; }) .def(py::pickle(ContextPhaseParamsGetState, ContextPhaseParamsSetState)); auto EagleDecodingConfigGetstate = [](tle::EagleConfig const& self) { return py::make_tuple(self.getEagleChoices(), self.isGreedySampling(), self.getPosteriorThreshold(), self.useDynamicTree(), self.getDynamicTreeMaxTopK()); }; auto EagleDecodingConfigSetstate = [](py::tuple const& state) { if (state.size() != 5) { throw std::runtime_error("Invalid EagleConfig state!"); } return tle::EagleConfig(state[0].cast(), state[1].cast(), state[2].cast>(), state[3].cast(), state[4].cast>()); }; py::class_(m, "EagleConfig") .def(py::init, bool, std::optional, bool, std::optional>(), py::arg("eagle_choices") = py::none(), py::arg("greedy_sampling") = true, py::arg("posterior_threshold") = py::none(), py::arg("use_dynamic_tree") = false, py::arg("dynamic_tree_max_topK") = py::none()) .def_property_readonly("eagle_choices", &tle::EagleConfig::getEagleChoices) .def_property_readonly("greedy_sampling", &tle::EagleConfig::isGreedySampling) .def_property_readonly("posterior_threshold", &tle::EagleConfig::getPosteriorThreshold) .def_property_readonly("use_dynamic_tree", &tle::EagleConfig::useDynamicTree) .def_property_readonly("dynamic_tree_max_topK", &tle::EagleConfig::getDynamicTreeMaxTopK) .def(py::pickle(EagleDecodingConfigGetstate, EagleDecodingConfigSetstate)); // Guided decoding params auto pyGuidedDecodingParams = py::class_(m, "GuidedDecodingParams"); py::enum_(pyGuidedDecodingParams, "GuideType") .value("JSON", tle::GuidedDecodingParams::GuideType::kJSON) .value("JSON_SCHEMA", tle::GuidedDecodingParams::GuideType::kJSON_SCHEMA) .value("REGEX", tle::GuidedDecodingParams::GuideType::kREGEX) .value("EBNF_GRAMMAR", tle::GuidedDecodingParams::GuideType::kEBNF_GRAMMAR); auto guidedDecodingParamsGetstate = [](tle::GuidedDecodingParams const& self) { return py::make_tuple(self.getGuideType(), self.getGuide()); }; auto guidedDecodingParamsSetstate = [](py::tuple state) { if (state.size() != 2) { throw std::runtime_error("Invalid GuidedDecodingParams state!"); } return tle::GuidedDecodingParams( state[0].cast(), state[1].cast>()); }; pyGuidedDecodingParams .def(py::init>(), py::arg("guide_type"), py::arg("guide") = py::none()) .def_property_readonly("guide_type", &tle::GuidedDecodingParams::getGuideType) .def_property_readonly("guide", &tle::GuidedDecodingParams::getGuide) .def(py::pickle(guidedDecodingParamsGetstate, guidedDecodingParamsSetstate)); auto requestGetstate = [](tle::Request const& self) { return py::make_tuple(self.getInputTokenIds(), self.getMaxTokens(), self.getStreaming(), self.getSamplingConfig(), self.getOutputConfig(), self.getEndId(), self.getPadId(), self.getPositionIds(), self.getBadWords(), self.getStopWords(), self.getEmbeddingBias(), self.getExternalDraftTokensConfig(), self.getPromptTuningConfig(), self.getMultimodalEmbedding(), self.getMropeConfig(), self.getLoraConfig(), self.getLookaheadConfig(), self.getKvCacheRetentionConfig(), self.getLogitsPostProcessorName(), self.getLogitsPostProcessor(), self.getEncoderInputTokenIds(), self.getClientId(), self.getReturnAllGeneratedTokens(), self.getPriority(), self.getRequestType(), self.getContextPhaseParams(), self.getEncoderInputFeatures(), self.getEncoderOutputLength(), self.getCrossAttentionMask(), self.getEagleConfig(), self.getSkipCrossAttnBlocks(), self.getGuidedDecodingParams()); }; auto requestSetstate = [](py::tuple const& state) { if (state.size() != 32) { throw std::runtime_error("Invalid Request state!"); } return std::make_unique(state[0].cast(), state[1].cast(), state[2].cast(), state[3].cast(), state[4].cast(), state[5].cast>(), state[6].cast>(), state[7].cast>>(), state[8].cast>>(), state[9].cast>>(), state[10].cast>(), state[11].cast>(), state[12].cast>(), state[13].cast>(), state[14].cast>(), state[15].cast>(), state[16].cast>(), state[17].cast>(), state[18].cast>(), state[19].cast>(), state[20].cast>(), state[21].cast>(), state[22].cast(), state[23].cast(), state[24].cast(), state[25].cast>(), state[26].cast>(), state[27].cast>(), state[28].cast>(), 1, state[29].cast>(), state[30].cast>(), state[31].cast>()); }; py::class_ request(m, "Request", pybind11::dynamic_attr()); request // A modified version of constructor to accpect deprecated args maxNewTokens // TODO(enweiz): use the original constructor after the deprecated args are removed .def(py::init( [](tle::VecTokens const& inputTokenIds, std::optional maxTokens, std::optional maxNewTokens, bool streaming, tle::SamplingConfig const& samplingConfig, tle::OutputConfig const& outputConfig, std::optional const& endId, std::optional const& padId, std::optional> const& positionIds, std::optional> const& badWords, std::optional> const& stopWords, std::optional const& embeddingBias, std::optional const& externalDraftTokensConfig, std::optional const& pTuningConfig, std::optional const& multimodalEmbedding, std::optional const& mRopeConfig, std::optional const& loraConfig, std::optional lookaheadConfig, std::optional const& kvCacheRetentionConfig, std::optional const& logitsPostProcessorName, std::optional const& logitsPostProcessor, std::optional const& encoderInputTokenIds, std::optional clientId, bool returnAllGeneratedTokens, tle::PriorityType priority, tle::RequestType type, std::optional const& contextPhaseParams, std::optional const& encoderInputFeatures, std::optional encoderOutputLength, std::optional const& crossAttentionMask, std::optional const& eagleConfig, std::optional const& skipCrossAttnBlocks, std::optional const& guidedDecodingParams, std::optional const& languageAdapterUid) { if (maxNewTokens.has_value()) { TLLM_LOG_WARNING("max_new_tokens is being deprecated; please use max_tokens instead."); if (!maxTokens.has_value()) { maxTokens = maxNewTokens; } } TLLM_CHECK_WITH_INFO(maxTokens.has_value(), "missing required argument max_tokens"); return std::make_unique(inputTokenIds, maxTokens.value(), streaming, samplingConfig, outputConfig, endId, padId, positionIds, badWords, stopWords, embeddingBias, externalDraftTokensConfig, pTuningConfig, multimodalEmbedding, mRopeConfig, loraConfig, lookaheadConfig, kvCacheRetentionConfig, logitsPostProcessorName, logitsPostProcessor, encoderInputTokenIds, clientId, returnAllGeneratedTokens, priority, type, contextPhaseParams, encoderInputFeatures, encoderOutputLength, crossAttentionMask, 1, eagleConfig, skipCrossAttnBlocks, guidedDecodingParams, languageAdapterUid); }), py::arg("input_token_ids"), py::kw_only(), py::arg("max_tokens") = py::none(), py::arg("max_new_tokens") = py::none(), py::arg("streaming") = false, py::arg_v("sampling_config", tle::SamplingConfig(), "SamplingConfig()"), py::arg_v("output_config", tle::OutputConfig(), "OutputConfig()"), py::arg("end_id") = py::none(), py::arg("pad_id") = py::none(), py::arg("position_ids") = py::none(), py::arg("bad_words") = py::none(), py::arg("stop_words") = py::none(), py::arg("embedding_bias") = py::none(), py::arg("external_draft_tokens_config") = py::none(), py::arg("prompt_tuning_config") = py::none(), py::arg("multimodal_embedding") = py::none(), py::arg("mrope_config") = py::none(), py::arg("lora_config") = py::none(), py::arg("lookahead_config") = py::none(), py::arg("kv_cache_retention_config") = py::none(), py::arg("logits_post_processor_name") = py::none(), py::arg("logits_post_processor") = py::none(), py::arg("encoder_input_token_ids") = py::none(), py::arg("client_id") = py::none(), py::arg("return_all_generated_tokens") = false, py::arg("priority") = tle::Request::kDefaultPriority, py::arg_v("type", tle::RequestType::REQUEST_TYPE_CONTEXT_AND_GENERATION, "RequestType.REQUEST_TYPE_CONTEXT_AND_GENERATION"), py::arg("context_phase_params") = py::none(), py::arg("encoder_input_features") = py::none(), py::arg("encoder_output_length") = py::none(), py::arg("cross_attention_mask") = py::none(), py::arg("eagle_config") = py::none(), py::arg("skip_cross_attn_blocks") = py::none(), py::arg("guided_decoding_params") = py::none(), py::arg("language_adapter_uid") = py::none()) .def_property_readonly("input_token_ids", &tle::Request::getInputTokenIds) .def_property_readonly("max_tokens", &tle::Request::getMaxTokens) .def_property_readonly("max_new_tokens", &tle::Request::getMaxNewTokens) .def_property("streaming", &tle::Request::getStreaming, &tle::Request::setStreaming) .def_property("sampling_config", &tle::Request::getSamplingConfig, &tle::Request::setSamplingConfig) .def_property("output_config", &tle::Request::getOutputConfig, &tle::Request::setOutputConfig) .def_property("end_id", &tle::Request::getEndId, &tle::Request::setEndId) .def_property("pad_id", &tle::Request::getPadId, &tle::Request::setPadId) .def_property("position_ids", &tle::Request::getPositionIds, &tle::Request::setPositionIds) .def_property("bad_words", &tle::Request::getBadWords, &tle::Request::setBadWords) .def_property("stop_words", &tle::Request::getStopWords, &tle::Request::setStopWords) .def_property("embedding_bias", &tle::Request::getEmbeddingBias, &tle::Request::setEmbeddingBias) .def_property("external_draft_tokens_config", &tle::Request::getExternalDraftTokensConfig, &tle::Request::setExternalDraftTokensConfig) .def_property( "prompt_tuning_config", &tle::Request::getPromptTuningConfig, &tle::Request::setPromptTuningConfig) .def_property( "multimodal_embedding", &tle::Request::getMultimodalEmbedding, &tle::Request::setMultimodalEmbedding) .def_property("mrope_config", &tle::Request::getMropeConfig, &tle::Request::setMropeConfig) .def_property("lora_config", &tle::Request::getLoraConfig, &tle::Request::setLoraConfig) .def_property("lookahead_config", &tle::Request::getLookaheadConfig, &tle::Request::setLookaheadConfig) .def_property("kv_cache_retention_config", &tle::Request::getKvCacheRetentionConfig, &tle::Request::setKvCacheRetentionConfig) .def_property("logits_post_processor_name", &tle::Request::getLogitsPostProcessorName, &tle::Request::setLogitsPostProcessorName) .def_property( "logits_post_processor", &tle::Request::getLogitsPostProcessor, &tle::Request::setLogitsPostProcessor) .def_property( "encoder_input_token_ids", &tle::Request::getEncoderInputTokenIds, &tle::Request::setEncoderInputTokenIds) .def_property("client_id", &tle::Request::getClientId, &tle::Request::setClientId) .def_property("return_all_generated_tokens", &tle::Request::getReturnAllGeneratedTokens, &tle::Request::setReturnAllGeneratedTokens) .def_property("request_type", &tle::Request::getRequestType, &tle::Request::setRequestType) .def_property( "encoder_input_features", &tle::Request::getEncoderInputFeatures, &tle::Request::setEncoderInputFeatures) .def_property( "cross_attention_mask", &tle::Request::getCrossAttentionMask, &tle::Request::setCrossAttentionMask) .def_property("eagle_config", &tle::Request::getEagleConfig, &tle::Request::setEagleConfig) .def_property( "skip_cross_attn_blocks", &tle::Request::getSkipCrossAttnBlocks, &tle::Request::setSkipCrossAttnBlocks) .def_property( "guided_decoding_params", &tle::Request::getGuidedDecodingParams, &tle::Request::setGuidedDecodingParams) .def_property("allotted_time_ms", &tle::Request::getAllottedTimeMs, &tle::Request::setAllottedTimeMs) .def_property( "context_phase_params", &tle::Request::getContextPhaseParams, &tle::Request::setContextPhaseParams) .def(py::pickle(requestGetstate, requestSetstate)); request.attr("BATCHED_POST_PROCESSOR_NAME") = tle::Request::kBatchedPostProcessorName; py::class_(m, "SpeculativeDecodingFastLogitsInfo") .def(py::init<>()) .def_readwrite("draft_request_id", &tle::SpeculativeDecodingFastLogitsInfo::draftRequestId) .def_readwrite("draft_participant_id", &tle::SpeculativeDecodingFastLogitsInfo::draftParticipantId) .def("to_tensor", &tle::SpeculativeDecodingFastLogitsInfo::toTensor); auto requestPerfMetrics = py::class_(m, "RequestPerfMetrics"); py::class_(m, "TimingMetrics") .def(py::init<>()) .def_readwrite("arrival_time", &tle::RequestPerfMetrics::TimingMetrics::arrivalTime) .def_readwrite("first_scheduled_time", &tle::RequestPerfMetrics::TimingMetrics::firstScheduledTime) .def_readwrite("first_token_time", &tle::RequestPerfMetrics::TimingMetrics::firstTokenTime) .def_readwrite("last_token_time", &tle::RequestPerfMetrics::TimingMetrics::lastTokenTime) .def_readwrite("kv_cache_transfer_start", &tle::RequestPerfMetrics::TimingMetrics::kvCacheTransferStart) .def_readwrite("kv_cache_transfer_end", &tle::RequestPerfMetrics::TimingMetrics::kvCacheTransferEnd); py::class_(m, "KvCacheMetrics") .def(py::init<>()) .def_readwrite("num_total_allocated_blocks", &tle::RequestPerfMetrics::KvCacheMetrics::numTotalAllocatedBlocks) .def_readwrite("num_new_allocated_blocks", &tle::RequestPerfMetrics::KvCacheMetrics::numNewAllocatedBlocks) .def_readwrite("num_reused_blocks", &tle::RequestPerfMetrics::KvCacheMetrics::numReusedBlocks) .def_readwrite("num_missed_blocks", &tle::RequestPerfMetrics::KvCacheMetrics::numMissedBlocks) .def_readwrite("kv_cache_hit_rate", &tle::RequestPerfMetrics::KvCacheMetrics::kvCacheHitRate); py::class_(m, "SpeculativeDecodingMetrics") .def(py::init<>()) .def_readwrite("acceptance_rate", &tle::RequestPerfMetrics::SpeculativeDecodingMetrics::acceptanceRate) .def_readwrite("total_accepted_draft_tokens", &tle::RequestPerfMetrics::SpeculativeDecodingMetrics::totalAcceptedDraftTokens) .def_readwrite("total_draft_tokens", &tle::RequestPerfMetrics::SpeculativeDecodingMetrics::totalDraftTokens); // There's a circular dependency between the declaration of the TimingMetrics and RequestPerfMetrics bindings. // Defer definition of the RequestPerfMetrics bindings until the TimingMetrics have been defined. requestPerfMetrics.def(py::init<>()) .def_readwrite("timing_metrics", &tle::RequestPerfMetrics::timingMetrics) .def_readwrite("kv_cache_metrics", &tle::RequestPerfMetrics::kvCacheMetrics) .def_readwrite("speculative_decoding", &tle::RequestPerfMetrics::speculativeDecoding) .def_readwrite("first_iter", &tle::RequestPerfMetrics::firstIter) .def_readwrite("last_iter", &tle::RequestPerfMetrics::lastIter) .def_readwrite("iter", &tle::RequestPerfMetrics::iter); py::class_(m, "AdditionalOutput") .def(py::init([](std::string const& name, tle::Tensor const& output) { return std::make_unique(name, output); })) .def_readwrite("name", &tle::AdditionalOutput::name) .def_readwrite("output", &tle::AdditionalOutput::output); auto resultSetstate = [](py::tuple const& state) { if (state.size() != 12) { throw std::runtime_error("Invalid Request state!"); } tle::Result result; result.isFinal = state[0].cast(); result.outputTokenIds = state[1].cast>(); result.cumLogProbs = state[2].cast>>(); result.logProbs = state[3].cast>>>(); result.contextLogits = state[4].cast>(); result.generationLogits = state[5].cast>(); result.encoderOutput = state[6].cast>(); result.finishReasons = state[7].cast>(); result.sequenceIndex = state[8].cast(); result.isSequenceFinal = state[9].cast(); result.decodingIter = state[10].cast(); result.contextPhaseParams = state[11].cast>(); return std::make_unique(result); }; auto resultGetstate = [](tle::Result const& self) { return py::make_tuple(self.isFinal, self.outputTokenIds, self.cumLogProbs, self.logProbs, self.contextLogits, self.generationLogits, self.encoderOutput, self.finishReasons, self.sequenceIndex, self.isSequenceFinal, self.decodingIter, self.contextPhaseParams); }; py::class_(m, "Result") .def(py::init<>()) .def_readwrite("is_final", &tle::Result::isFinal) .def_readwrite("output_token_ids", &tle::Result::outputTokenIds) .def_readwrite("cum_log_probs", &tle::Result::cumLogProbs) .def_readwrite("log_probs", &tle::Result::logProbs) .def_readwrite("context_logits", &tle::Result::contextLogits) .def_readwrite("generation_logits", &tle::Result::generationLogits) .def_readwrite("spec_dec_fast_logits_info", &tle::Result::specDecFastLogitsInfo) .def_readwrite("encoder_output", &tle::Result::encoderOutput) .def_readwrite("finish_reasons", &tle::Result::finishReasons) .def_readwrite("sequence_index", &tle::Result::sequenceIndex) .def_readwrite("is_sequence_final", &tle::Result::isSequenceFinal) .def_readwrite("decoding_iter", &tle::Result::decodingIter) .def_readwrite("context_phase_params", &tle::Result::contextPhaseParams) .def_readwrite("request_perf_metrics", &tle::Result::requestPerfMetrics) .def_readwrite("additional_outputs", &tle::Result::additionalOutputs) .def_readwrite("context_phase_params", &tle::Result::contextPhaseParams) .def(py::pickle(resultGetstate, resultSetstate)); auto responseGetstate = [](tle::Response const& self) { return py::make_tuple(self.getRequestId(), self.getResult(), self.getClientId()); }; auto responseSetstate = [](py::tuple const& state) { if (state.size() != 3) { throw std::runtime_error("Invalid Request state!"); } return std::make_unique( state[0].cast(), state[1].cast(), state[2].cast()); }; py::class_(m, "Response") .def(py::init>(), py::arg("request_id"), py::arg("error_msg"), py::arg("client_id") = std::nullopt) .def(py::init>(), py::arg("request_id"), py::arg("result"), py::arg("client_id") = std::nullopt) .def_property_readonly("request_id", &tle::Response::getRequestId) .def_property_readonly("client_id", &tle::Response::getClientId) .def("has_error", &tle::Response::hasError) .def_property_readonly("error_msg", &tle::Response::getErrorMsg) .def_property_readonly("result", &tle::Response::getResult) .def("clear_context_logits", [](tle::Response& self) { if (!self.hasError()) { auto& result = const_cast(self.getResult()); result.contextLogits.reset(); } }) .def("clear_generation_logits", [](tle::Response& self) { if (!self.hasError()) { auto& result = const_cast(self.getResult()); result.generationLogits.reset(); } }) .def(py::pickle(responseGetstate, responseSetstate)); } } // namespace tensorrt_llm::pybind::executor