/* * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "request.h" #include "tensorrt_llm/common/assert.h" #include "tensorrt_llm/common/logger.h" #include "tensorrt_llm/executor/executor.h" #include "tensorrt_llm/executor/serializeUtils.h" #include "tensorrt_llm/executor/tensor.h" #include "tensorrt_llm/executor/types.h" #include "tensorrt_llm/nanobind/common/customCasters.h" #include "tensorrt_llm/runtime/cudaStream.h" #include #include #include #include #include #include #include #include #include #include #include #include namespace nb = nanobind; namespace tle = tensorrt_llm::executor; using Tensor = tle::Tensor; using SizeType32 = tle::SizeType32; using FloatType = tle::FloatType; using VecTokens = tle::VecTokens; using IdType = tle::IdType; using VecTokenExtraIds = tle::VecTokenExtraIds; namespace tensorrt_llm::nanobind::executor { void initRequestBindings(nb::module_& m) { nb::enum_(m, "RequestType") .value("REQUEST_TYPE_CONTEXT_AND_GENERATION", tle::RequestType::REQUEST_TYPE_CONTEXT_AND_GENERATION) .value("REQUEST_TYPE_CONTEXT_ONLY", tle::RequestType::REQUEST_TYPE_CONTEXT_ONLY) .value("REQUEST_TYPE_GENERATION_ONLY", tle::RequestType::REQUEST_TYPE_GENERATION_ONLY); nb::enum_(m, "FinishReason") .value("NOT_FINISHED", tle::FinishReason::kNOT_FINISHED) .value("END_ID", tle::FinishReason::kEND_ID) .value("STOP_WORDS", 
tle::FinishReason::kSTOP_WORDS) .value("LENGTH", tle::FinishReason::kLENGTH) .value("TIMED_OUT", tle::FinishReason::kTIMED_OUT) .value("CANCELLED", tle::FinishReason::kCANCELLED); nb::enum_(m, "KvCacheTransferMode") .value("DRAM", tle::KvCacheTransferMode::DRAM) .value("GDS", tle::KvCacheTransferMode::GDS) .value("POSIX_DEBUG_FALLBACK", tle::KvCacheTransferMode::POSIX_DEBUG_FALLBACK); auto samplingConfigGetstate = [](tle::SamplingConfig const& self) { return nb::make_tuple(self.getBeamWidth(), self.getTopK(), self.getTopP(), self.getTopPMin(), self.getTopPResetIds(), self.getTopPDecay(), self.getSeed(), self.getTemperature(), self.getMinTokens(), self.getBeamSearchDiversityRate(), self.getRepetitionPenalty(), self.getPresencePenalty(), self.getFrequencyPenalty(), self.getPromptIgnoreLength(), self.getLengthPenalty(), self.getEarlyStopping(), self.getNoRepeatNgramSize(), self.getNumReturnSequences(), self.getMinP(), self.getBeamWidthArray()); }; auto samplingConfigSetstate = [](tle::SamplingConfig& samplingConfig, nb::tuple const& state) { if (state.size() != 20) { throw std::runtime_error("Invalid SamplingConfig state!"); } new (&samplingConfig) tle::SamplingConfig(nb::cast(state[0]), // BeamWidth nb::cast>(state[1]), // TopK nb::cast>(state[2]), // TopP nb::cast>(state[3]), // TopPMin nb::cast>(state[4]), // TopPResetIds nb::cast>(state[5]), // TopPDecay nb::cast>(state[6]), // Seed nb::cast>(state[7]), // Temperature nb::cast>(state[8]), // MinTokens nb::cast>(state[9]), // BeamSearchDiversityRate nb::cast>(state[10]), // RepetitionPenalty nb::cast>(state[11]), // PresencePenalty nb::cast>(state[12]), // FrequencyPenalty nb::cast>(state[13]), // PromptIgnoreLength nb::cast>(state[14]), // LengthPenalty nb::cast>(state[15]), // EarlyStopping nb::cast>(state[16]), // NoRepeatNgramSize nb::cast>(state[17]), // NumReturnSequences nb::cast>(state[18]), // MinP nb::cast>>(state[19]) // BeamWidthArray ); }; nb::class_(m, "SamplingConfig") .def(nb::init const&, // 
beamWidth std::optional const&, // topP std::optional const&, // topPMin std::optional const&, // topPResetIds std::optional const&, // topPDecay std::optional const&, // seed std::optional const&, // temperature std::optional const&, // minTokens std::optional const&, // beamSearchDiversityRate std::optional const&, // repetitionPenalty std::optional const&, // presencePenalty std::optional const&, // frequencyPenalty std::optional const&, // promptIgnoreLength std::optional const&, // lengthPenalty std::optional const&, // earlyStopping std::optional const&, // noRepeatNgramSize std::optional const&, // numReturnSequences std::optional const&, // minP std::optional> const& // beamWidthArray >(), // clang-format off nb::arg("beam_width") = 1, nb::kw_only(), nb::arg("top_k") = nb::none(), nb::arg("top_p") = nb::none(), nb::arg("top_p_min") = nb::none(), nb::arg("top_p_reset_ids") = nb::none(), nb::arg("top_p_decay") = nb::none(), nb::arg("seed") = nb::none(), nb::arg("temperature") = nb::none(), nb::arg("min_tokens") = nb::none(), nb::arg("beam_search_diversity_rate") = nb::none(), nb::arg("repetition_penalty") = nb::none(), nb::arg("presence_penalty") = nb::none(), nb::arg("frequency_penalty") = nb::none(), nb::arg("prompt_ignore_length") = nb::none(), nb::arg("length_penalty") = nb::none(), nb::arg("early_stopping") = nb::none(), nb::arg("no_repeat_ngram_size") = nb::none(), nb::arg("num_return_sequences") = nb::none(), nb::arg("min_p") = nb::none(), nb::arg("beam_width_array") = nb::none()) // clang-format on .def_prop_rw("beam_width", &tle::SamplingConfig::getBeamWidth, &tle::SamplingConfig::setBeamWidth) .def_prop_rw("top_k", &tle::SamplingConfig::getTopK, &tle::SamplingConfig::setTopK) .def_prop_rw("top_p", &tle::SamplingConfig::getTopP, &tle::SamplingConfig::setTopP) .def_prop_rw("top_p_min", &tle::SamplingConfig::getTopPMin, &tle::SamplingConfig::setTopPMin) .def_prop_rw("top_p_reset_ids", &tle::SamplingConfig::getTopPResetIds, 
&tle::SamplingConfig::setTopPResetIds) .def_prop_rw("top_p_decay", &tle::SamplingConfig::getTopPDecay, &tle::SamplingConfig::setTopPDecay) .def_prop_rw("seed", &tle::SamplingConfig::getSeed, &tle::SamplingConfig::setSeed) .def_prop_rw("temperature", &tle::SamplingConfig::getTemperature, &tle::SamplingConfig::setTemperature) .def_prop_rw("min_tokens", &tle::SamplingConfig::getMinTokens, &tle::SamplingConfig::setMinTokens) .def_prop_rw("beam_search_diversity_rate", &tle::SamplingConfig::getBeamSearchDiversityRate, &tle::SamplingConfig::setBeamSearchDiversityRate) .def_prop_rw("repetition_penalty", &tle::SamplingConfig::getRepetitionPenalty, &tle::SamplingConfig::setRepetitionPenalty) .def_prop_rw("presence_penalty", &tle::SamplingConfig::getPresencePenalty, [](tle::SamplingConfig& self, std::optional v) { self.setPresencePenalty(v); }) .def_prop_rw( "frequency_penalty", &tle::SamplingConfig::getFrequencyPenalty, &tle::SamplingConfig::setFrequencyPenalty) .def_prop_rw("prompt_ignore_length", &tle::SamplingConfig::getPromptIgnoreLength, &tle::SamplingConfig::setPromptIgnoreLength) .def_prop_rw("length_penalty", &tle::SamplingConfig::getLengthPenalty, &tle::SamplingConfig::setLengthPenalty) .def_prop_rw("early_stopping", &tle::SamplingConfig::getEarlyStopping, &tle::SamplingConfig::setEarlyStopping) .def_prop_rw("no_repeat_ngram_size", &tle::SamplingConfig::getNoRepeatNgramSize, &tle::SamplingConfig::setNoRepeatNgramSize) .def_prop_rw("num_return_sequences", &tle::SamplingConfig::getNumReturnSequences, &tle::SamplingConfig::setNumReturnSequences) .def_prop_rw("min_p", &tle::SamplingConfig::getMinP, &tle::SamplingConfig::setMinP) .def_prop_rw( "beam_width_array", &tle::SamplingConfig::getBeamWidthArray, &tle::SamplingConfig::setBeamWidthArray) .def("__getstate__", samplingConfigGetstate) .def("__setstate__", samplingConfigSetstate); auto additionalModelOutputGetstate = [](tle::AdditionalModelOutput const& self) { return nb::make_tuple(self.name, self.gatherContext); }; 
auto additionalModelOutputSetstate = [](tle::AdditionalModelOutput& additionalModelOutput, nb::tuple const& state) { if (state.size() != 2) { throw std::runtime_error("Invalid AdditionalModelOutput state!"); } new (&additionalModelOutput) tle::AdditionalModelOutput(nb::cast(state[0]), nb::cast(state[1])); }; nb::class_(m, "AdditionalModelOutput") .def(nb::init(), nb::arg("name"), nb::arg("gather_context") = false) .def_rw("name", &tle::AdditionalModelOutput::name) .def_rw("gather_context", &tle::AdditionalModelOutput::gatherContext) .def("__getstate__", additionalModelOutputGetstate) .def("__setstate__", additionalModelOutputSetstate); auto outputConfigGetstate = [](tle::OutputConfig const& self) { return nb::make_tuple(self.returnLogProbs, self.returnContextLogits, self.returnGenerationLogits, self.excludeInputFromOutput, self.returnEncoderOutput, self.returnPerfMetrics, self.additionalModelOutputs); }; auto outputConfigSetstate = [](tle::OutputConfig& outputConfig, nb::tuple const& state) { if (state.size() != 7) { throw std::runtime_error("Invalid OutputConfig state!"); } new (&outputConfig) tle::OutputConfig(nb::cast(state[0]), nb::cast(state[1]), nb::cast(state[2]), nb::cast(state[3]), nb::cast(state[4]), nb::cast(state[5]), nb::cast>>(state[6])); }; nb::class_(m, "OutputConfig") .def( "__init__", [](tle::OutputConfig& self, std::optional return_log_probs, std::optional return_context_logits, std::optional return_generation_logits, std::optional exclude_input_from_output, std::optional return_encoder_output, std::optional return_perf_metrics, std::optional> additional_model_outputs) { new (&self) tle::OutputConfig(return_log_probs.value_or(false), return_context_logits.value_or(false), return_generation_logits.value_or(false), exclude_input_from_output.value_or(false), return_encoder_output.value_or(false), return_perf_metrics.value_or(false), additional_model_outputs); }, nb::arg("return_log_probs") = nb::none(), nb::arg("return_context_logits") = nb::none(), 
nb::arg("return_generation_logits") = nb::none(), nb::arg("exclude_input_from_output") = nb::none(), nb::arg("return_encoder_output") = nb::none(), nb::arg("return_perf_metrics") = nb::none(), nb::arg("additional_model_outputs") = nb::none()) .def_rw("return_log_probs", &tle::OutputConfig::returnLogProbs) .def_rw("return_context_logits", &tle::OutputConfig::returnContextLogits) .def_rw("return_generation_logits", &tle::OutputConfig::returnGenerationLogits) .def_rw("exclude_input_from_output", &tle::OutputConfig::excludeInputFromOutput) .def_rw("return_encoder_output", &tle::OutputConfig::returnEncoderOutput) .def_rw("return_perf_metrics", &tle::OutputConfig::returnPerfMetrics) .def_rw("additional_model_outputs", &tle::OutputConfig::additionalModelOutputs) .def("__getstate__", outputConfigGetstate) .def("__setstate__", outputConfigSetstate); auto externalDraftTokensConfigGetstate = [](tle::ExternalDraftTokensConfig const& self) { return nb::make_tuple(self.getTokens(), self.getLogits(), self.getAcceptanceThreshold()); }; auto externalDraftTokensConfigSetstate = [](tle::ExternalDraftTokensConfig& externalDraftTokensConfig, nb::tuple const& state) { if (state.size() != 3) { throw std::runtime_error("Invalid ExternalDraftTokensConfig state!"); } new (&externalDraftTokensConfig) tle::ExternalDraftTokensConfig(nb::cast(state[0]), nb::cast>(state[1]), nb::cast>(state[2])); }; nb::class_(m, "ExternalDraftTokensConfig") .def(nb::init, std::optional const&, std::optional>(), nb::arg("tokens"), nb::arg("logits") = nb::none(), nb::arg("acceptance_threshold") = nb::none(), nb::arg("fast_logits") = nb::none()) .def_prop_ro("tokens", &tle::ExternalDraftTokensConfig::getTokens) .def_prop_ro("logits", &tle::ExternalDraftTokensConfig::getLogits) .def_prop_ro("acceptance_threshold", &tle::ExternalDraftTokensConfig::getAcceptanceThreshold) .def("__getstate__", externalDraftTokensConfigGetstate) .def("__setstate__", externalDraftTokensConfigSetstate) .def_prop_ro("fast_logits", 
&tle::ExternalDraftTokensConfig::getFastLogits); auto promptTuningConfigGetstate = [](tle::PromptTuningConfig const& self) { return nb::make_tuple(self.getEmbeddingTable(), self.getInputTokenExtraIds()); }; auto promptTuningConfigSetstate = [](tle::PromptTuningConfig& promptTuningConfig, nb::tuple const& state) { if (state.size() != 2) { throw std::runtime_error("Invalid PromptTuningConfig state!"); } new (&promptTuningConfig) tle::PromptTuningConfig(nb::cast(state[0]), nb::cast>(state[1])); }; nb::class_(m, "PromptTuningConfig") .def(nb::init>(), nb::arg("embedding_table"), nb::arg("input_token_extra_ids") = nb::none()) .def_prop_ro("embedding_table", &tle::PromptTuningConfig::getEmbeddingTable) .def_prop_ro("input_token_extra_ids", &tle::PromptTuningConfig::getInputTokenExtraIds) .def("__getstate__", promptTuningConfigGetstate) .def("__setstate__", promptTuningConfigSetstate); auto loraConfigGetstate = [](tle::LoraConfig const& self) { return nb::make_tuple(self.getTaskId(), self.getWeights(), self.getConfig()); }; auto loraConfigSetstate = [](tle::LoraConfig& loraConfig, nb::tuple const& state) { if (state.size() != 3) { throw std::runtime_error("Invalid LoraConfig state!"); } new (&loraConfig) tle::LoraConfig(nb::cast(state[0]), nb::cast>(state[1]), nb::cast>(state[2])); }; nb::class_(m, "LoraConfig") .def(nb::init, std::optional>(), nb::arg("task_id"), nb::arg("weights") = nb::none(), nb::arg("config") = nb::none()) .def_prop_ro("task_id", &tle::LoraConfig::getTaskId) .def_prop_ro("weights", &tle::LoraConfig::getWeights) .def_prop_ro("config", &tle::LoraConfig::getConfig) .def("__getstate__", loraConfigGetstate) .def("__setstate__", loraConfigSetstate); auto multimodalInputGetstate = [](tle::MultimodalInput const& self) { return nb::make_tuple(self.getMultimodalHashes(), self.getMultimodalPositions(), self.getMultimodalLengths()); }; auto multimodalInputSetstate = [](tle::MultimodalInput& multimodalInput, nb::tuple const& state) { if (state.size() != 3) { 
throw std::runtime_error("Invalid MultimodalInput state!"); } new (&multimodalInput) tle::MultimodalInput(nb::cast>>(state[0]), nb::cast>(state[1]), nb::cast>(state[2])); }; nb::class_(m, "MultimodalInput") .def(nb::init>, std::vector, std::vector>(), nb::arg("multimodal_hashes"), nb::arg("multimodal_positions"), nb::arg("multimodal_lengths")) .def_prop_ro("multimodal_hashes", &tle::MultimodalInput::getMultimodalHashes) .def_prop_ro("multimodal_positions", &tle::MultimodalInput::getMultimodalPositions) .def_prop_ro("multimodal_lengths", &tle::MultimodalInput::getMultimodalLengths) .def("__getstate__", multimodalInputGetstate) .def("__setstate__", multimodalInputSetstate); auto MropeConfigGetstate = [](tle::MropeConfig const& self) { return nb::make_tuple(self.getMRopeRotaryCosSin(), self.getMRopePositionDeltas()); }; auto MropeConfigSetstate = [](tle::MropeConfig& mropeConfig, nb::tuple const& state) { if (state.size() != 2) { throw std::runtime_error("Invalid MropeConfig state!"); } new (&mropeConfig) tle::MropeConfig(nb::cast(state[0]), nb::cast(state[1])); }; nb::class_(m, "MropeConfig") .def(nb::init(), nb::arg("mrope_rotary_cos_sin"), nb::arg("mrope_position_deltas")) .def_prop_ro("mrope_rotary_cos_sin", &tle::MropeConfig::getMRopeRotaryCosSin) .def_prop_ro("mrope_position_deltas", &tle::MropeConfig::getMRopePositionDeltas) .def("__getstate__", MropeConfigGetstate) .def("__setstate__", MropeConfigSetstate); auto lookaheadDecodingConfigGetstate = [](tle::LookaheadDecodingConfig const& self) { return nb::make_tuple(self.getWindowSize(), self.getNgramSize(), self.getVerificationSetSize()); }; auto lookaheadDecodingConfigSetstate = [](tle::LookaheadDecodingConfig& lookaheadDecodingConfig, nb::tuple const& state) { if (state.size() != 3) { throw std::runtime_error("Invalid LookaheadDecodingConfig state!"); } new (&lookaheadDecodingConfig) tle::LookaheadDecodingConfig( nb::cast(state[0]), nb::cast(state[1]), nb::cast(state[2])); }; nb::class_(m, 
"LookaheadDecodingConfig") .def(nb::init(), nb::arg("max_window_size"), nb::arg("max_ngram_size"), nb::arg("max_verification_set_size")) .def_prop_ro("max_window_size", &tle::LookaheadDecodingConfig::getWindowSize) .def_prop_ro("max_ngram_size", &tle::LookaheadDecodingConfig::getNgramSize) .def_prop_ro("max_verification_set_size", &tle::LookaheadDecodingConfig::getVerificationSetSize) .def("calculate_speculative_resource", &tle::LookaheadDecodingConfig::calculateSpeculativeResource) .def_static( "calculate_speculative_resource_tuple", &tle::LookaheadDecodingConfig::calculateSpeculativeResourceTuple) .def("__getstate__", lookaheadDecodingConfigGetstate) .def("__setstate__", lookaheadDecodingConfigSetstate) .def_static("get_default_lookahead_decoding_window", []() { return tle::LookaheadDecodingConfig::kDefaultLookaheadDecodingWindow; }) .def_static("get_default_lookahead_decoding_ngram", []() { return tle::LookaheadDecodingConfig::kDefaultLookaheadDecodingNgram; }) .def_static("get_default_lookahead_decoding_verification_set", []() { return tle::LookaheadDecodingConfig::kDefaultLookaheadDecodingVerificationSet; }); auto TokenRangeRetentionConfigGetstate = [](tle::KvCacheRetentionConfig::TokenRangeRetentionConfig const& self) { return nb::make_tuple(self.tokenStart, self.tokenEnd, self.priority, self.durationMs); }; auto TokenRangeRetentionConfigSetstate = [](tle::KvCacheRetentionConfig::TokenRangeRetentionConfig& tokenRangeRetentionConfig, nb::tuple const& state) { if (state.size() != 4) { throw std::runtime_error("Invalid state!"); } new (&tokenRangeRetentionConfig) tle::KvCacheRetentionConfig::TokenRangeRetentionConfig( nb::cast(state[0]), nb::cast>(state[1]), nb::cast(state[2]), nb::cast>(state[3])); }; auto kvCacheRetentionConfigGetstate = [](tle::KvCacheRetentionConfig const& self) { return nb::make_tuple(self.getTokenRangeRetentionConfigs(), self.getDecodeRetentionPriority(), self.getDecodeDurationMs(), self.getTransferMode(), self.getDirectory()); }; auto 
kvCacheRetentionConfigSetstate = [](tle::KvCacheRetentionConfig& kvCacheRetentionConfig, nb::tuple const& state) { /* Unpickle KvCacheRetentionConfig. Expected tuple layout, as produced by kvCacheRetentionConfigGetstate: (token_range_retention_configs, decode_retention_priority, decode_duration_ms, transfer_mode, directory). */ if (state.size() != 5) { throw std::runtime_error("Invalid state!"); } new (&kvCacheRetentionConfig) tle::KvCacheRetentionConfig( nb::cast>(state[0]), nb::cast(state[1]), nb::cast>(state[2]), nb::cast(state[3]), nb::cast(state[4])); }; /* Create the KvCacheRetentionConfig class object first: it is the enclosing scope for the nested TokenRangeRetentionConfig binding below. Its constructor and accessors are attached later. */ auto kvCacheRetentionConfig = nb::class_(m, "KvCacheRetentionConfig"); nb::class_( kvCacheRetentionConfig, "TokenRangeRetentionConfig") .def(nb::init, tle::RetentionPriority, std::optional>(), nb::arg("token_start"), nb::arg("token_end"), nb::arg("priority"), nb::arg("duration_ms") = nb::none()) .def_rw("token_start", &tle::KvCacheRetentionConfig::TokenRangeRetentionConfig::tokenStart) .def_rw("token_end", &tle::KvCacheRetentionConfig::TokenRangeRetentionConfig::tokenEnd) .def_rw("priority", &tle::KvCacheRetentionConfig::TokenRangeRetentionConfig::priority) .def_rw("duration_ms", &tle::KvCacheRetentionConfig::TokenRangeRetentionConfig::durationMs) .def("__getstate__", TokenRangeRetentionConfigGetstate) .def("__setstate__", TokenRangeRetentionConfigSetstate) .def("__eq__", &tle::KvCacheRetentionConfig::TokenRangeRetentionConfig::operator==); // There's a circular dependency between the TokenRangeRetentionConfig and KvCacheRetentionConfig bindings: the nested TokenRangeRetentionConfig class is registered inside the KvCacheRetentionConfig class object, while the KvCacheRetentionConfig constructor and accessors take/return TokenRangeRetentionConfig values. Defer definition of the KvCacheRetentionConfig bindings until the TokenRangeRetentionConfig bindings have been defined. 
kvCacheRetentionConfig .def(nb::init, tle::RetentionPriority, std::optional, tle::KvCacheTransferMode, std::string>(), nb::arg("token_range_retention_configs"), nb::arg("decode_retention_priority") = tle::KvCacheRetentionConfig::kDefaultRetentionPriority, nb::arg("decode_duration_ms") = nb::none(), nb::arg("transfer_mode") = tle::KvCacheTransferMode::DRAM, nb::arg("directory") = nb::none()) .def_prop_ro("token_range_retention_configs", &tle::KvCacheRetentionConfig::getTokenRangeRetentionConfigs) .def_prop_ro("decode_retention_priority", &tle::KvCacheRetentionConfig::getDecodeRetentionPriority) .def_prop_ro("decode_duration_ms", &tle::KvCacheRetentionConfig::getDecodeDurationMs) .def_prop_ro("transfer_mode", &tle::KvCacheRetentionConfig::getTransferMode) .def_prop_ro("directory", &tle::KvCacheRetentionConfig::getDirectory) .def("__getstate__", kvCacheRetentionConfigGetstate) .def("__setstate__", kvCacheRetentionConfigSetstate) .def("__eq__", &tle::KvCacheRetentionConfig::operator==); auto ContextPhaseParamsGetState = [](tle::ContextPhaseParams const& self) { if (self.getState() != nullptr) { auto serializedState = self.getSerializedState(); return nb::make_tuple(self.getFirstGenTokens(), self.getReqId(), nb::bytes(serializedState.data(), serializedState.size()), self.getDraftTokens()); } return nb::make_tuple(self.getFirstGenTokens(), self.getReqId(), nb::none(), self.getDraftTokens()); }; auto ContextPhaseParamsSetState = [](tle::ContextPhaseParams& contextPhaseParams, nb::tuple const& state) { if (state.size() != 4) { throw std::runtime_error("Invalid ContextPhaseParams state!"); } if (!state[2].is_none()) { auto opaque_state = nb::cast(state[2]); auto opaque_state_str_view = std::string_view(opaque_state.c_str(), opaque_state.size()); new (&contextPhaseParams) tle::ContextPhaseParams(nb::cast(state[0]), nb::cast(state[1]), std::vector(opaque_state_str_view.begin(), opaque_state_str_view.end()), nb::cast>(state[3])); } else { new (&contextPhaseParams) 
tle::ContextPhaseParams(nb::cast(state[0]), nb::cast(state[1]), nb::cast>(state[3])); } }; nb::class_(m, "ContextPhaseParams") .def( "__init__", [](tle::ContextPhaseParams& self, VecTokens const& first_gen_tokens, tle::ContextPhaseParams::RequestIdType req_id, std::optional const& opaque_state, std::optional const& draft_tokens) { if (opaque_state) { auto opaque_state_str_view = std::string_view(opaque_state.value().c_str(), opaque_state.value().size()); new (&self) tle::ContextPhaseParams(first_gen_tokens, req_id, std::vector(opaque_state_str_view.begin(), opaque_state_str_view.end()), draft_tokens); } else { new (&self) tle::ContextPhaseParams(first_gen_tokens, req_id, draft_tokens); } }, nb::arg("first_gen_tokens"), nb::arg("req_id"), nb::arg("opaque_state").none(), nb::arg("draft_tokens").none()) .def_prop_ro("first_gen_tokens", [](tle::ContextPhaseParams const& self) { return self.getFirstGenTokens(); }) .def_prop_ro("draft_tokens", [](tle::ContextPhaseParams const& self) { return self.getDraftTokens(); }) .def_prop_ro("req_id", &tle::ContextPhaseParams::getReqId) .def_prop_ro("opaque_state", [](tle::ContextPhaseParams const& self) { std::optional opaque_state{std::nullopt}; if (self.getState() != nullptr) { auto serializedState = self.getSerializedState(); opaque_state = nb::bytes(serializedState.data(), serializedState.size()); } return opaque_state; }) .def("__getstate__", ContextPhaseParamsGetState) .def("__setstate__", ContextPhaseParamsSetState); auto EagleDecodingConfigGetstate = [](tle::EagleConfig const& self) { return nb::make_tuple(self.getEagleChoices(), self.isGreedySampling(), self.getPosteriorThreshold(), self.useDynamicTree(), self.getDynamicTreeMaxTopK()); }; auto EagleDecodingConfigSetstate = [](tle::EagleConfig& self, nb::tuple const& state) { if (state.size() != 5) { throw std::runtime_error("Invalid EagleConfig state!"); } new (&self) tle::EagleConfig(nb::cast>(state[0]), nb::cast(state[1]), nb::cast>(state[2]), nb::cast(state[3]), 
nb::cast>(state[4])); }; nb::class_(m, "EagleConfig") .def(nb::init, bool, std::optional, bool, std::optional>(), nb::arg("eagle_choices") = nb::none(), nb::arg("greedy_sampling") = true, nb::arg("posterior_threshold") = nb::none(), nb::arg("use_dynamic_tree") = false, nb::arg("dynamic_tree_max_topK") = nb::none()) .def_prop_ro("eagle_choices", &tle::EagleConfig::getEagleChoices) .def_prop_ro("greedy_sampling", &tle::EagleConfig::isGreedySampling) .def_prop_ro("posterior_threshold", &tle::EagleConfig::getPosteriorThreshold) .def_prop_ro("use_dynamic_tree", &tle::EagleConfig::useDynamicTree) .def_prop_ro("dynamic_tree_max_topK", &tle::EagleConfig::getDynamicTreeMaxTopK) .def("__getstate__", EagleDecodingConfigGetstate) .def("__setstate__", EagleDecodingConfigSetstate); // Guided decoding params auto pyGuidedDecodingParams = nb::class_(m, "GuidedDecodingParams"); nb::enum_(pyGuidedDecodingParams, "GuideType") .value("JSON", tle::GuidedDecodingParams::GuideType::kJSON) .value("JSON_SCHEMA", tle::GuidedDecodingParams::GuideType::kJSON_SCHEMA) .value("REGEX", tle::GuidedDecodingParams::GuideType::kREGEX) .value("EBNF_GRAMMAR", tle::GuidedDecodingParams::GuideType::kEBNF_GRAMMAR) .value("STRUCTURAL_TAG", tle::GuidedDecodingParams::GuideType::kSTRUCTURAL_TAG); auto guidedDecodingParamsGetstate = [](tle::GuidedDecodingParams const& self) { return nb::make_tuple(self.getGuideType(), self.getGuide()); }; auto guidedDecodingParamsSetstate = [](tle::GuidedDecodingParams& self, nb::tuple const& state) { if (state.size() != 2) { throw std::runtime_error("Invalid GuidedDecodingParams state!"); } new (&self) tle::GuidedDecodingParams( nb::cast(state[0]), nb::cast>(state[1])); }; pyGuidedDecodingParams .def(nb::init>(), nb::arg("guide_type"), nb::arg("guide") = nb::none()) .def_prop_ro("guide_type", &tle::GuidedDecodingParams::getGuideType) .def_prop_ro("guide", &tle::GuidedDecodingParams::getGuide) .def("__getstate__", guidedDecodingParamsGetstate) .def("__setstate__", 
guidedDecodingParamsSetstate); auto requestGetstate = [](tle::Request const& self) { return nb::make_tuple(self.getInputTokenIds(), self.getMaxTokens(), self.getStreaming(), self.getSamplingConfig(), self.getOutputConfig(), self.getEndId(), self.getPadId(), self.getPositionIds(), self.getBadWords(), self.getStopWords(), self.getEmbeddingBias(), self.getExternalDraftTokensConfig(), self.getPromptTuningConfig(), self.getMultimodalInput(), self.getMultimodalEmbedding(), self.getMropeConfig(), self.getLoraConfig(), self.getLookaheadConfig(), self.getKvCacheRetentionConfig(), self.getLogitsPostProcessorName(), self.getLogitsPostProcessor(), self.getEncoderInputTokenIds(), self.getClientId(), self.getReturnAllGeneratedTokens(), self.getPriority(), self.getRequestType(), self.getContextPhaseParams(), self.getEncoderInputFeatures(), self.getEncoderOutputLength(), self.getCrossAttentionMask(), self.getEagleConfig(), self.getSkipCrossAttnBlocks(), self.getGuidedDecodingParams(), self.getCacheSaltID()); }; auto requestSetstate = [](tle::Request& self, nb::tuple const& state) { if (state.size() != 34) { throw std::runtime_error("Invalid Request state!"); } new (&self) tle::Request(nb::cast(state[0]), nb::cast(state[1]), nb::cast(state[2]), nb::cast(state[3]), nb::cast(state[4]), nb::cast>(state[5]), nb::cast>(state[6]), nb::cast>>(state[7]), nb::cast>>(state[8]), nb::cast>>(state[9]), nb::cast>(state[10]), nb::cast>(state[11]), nb::cast>(state[12]), nb::cast>(state[13]), nb::cast>(state[14]), nb::cast>(state[15]), nb::cast>(state[16]), nb::cast>(state[17]), nb::cast>(state[18]), nb::cast>(state[19]), nb::cast>(state[20]), nb::cast>(state[21]), nb::cast>(state[22]), nb::cast(state[23]), nb::cast(state[24]), nb::cast(state[25]), nb::cast>(state[26]), nb::cast>(state[27]), nb::cast>(state[28]), nb::cast>(state[29]), 1, nb::cast>(state[30]), nb::cast>(state[31]), nb::cast>(state[32]), nb::cast>(state[33])); }; nb::class_ request(m, "Request", nb::dynamic_attr()); request 
.def(nb::init const&, // endId std::optional const&, // padId std::optional>, // positionIds std::optional>, // badWords std::optional>, // stopWords std::optional, // embeddingBias std::optional, // externalDraftTokensConfig std::optional, // pTuningConfig std::optional, // multimodalInput std::optional, // multimodalEmbedding std::optional, // mRopeConfig std::optional, // loraConfig std::optional, // lookaheadConfig std::optional, // kvCacheRetentionConfig std::optional, // logitsPostProcessorName std::optional, // logitsPostProcessor std::optional, // encoderInputTokenIds std::optional, // clientId bool, // returnAllGeneratedTokens tle::PriorityType, // priority tle::RequestType, // type std::optional, // contextPhaseParams std::optional, // encoderInputFeatures std::optional, // encoderOutputLength std::optional, // crossAttentionMask SizeType32, // numReturnSequences std::optional, // eagleConfig std::optional, // skipCrossAttnBlocks std::optional, // guidedDecodingParams std::optional, // languageAdapterUid std::optional, // allottedTimeMs std::optional // cacheSaltID >(), // clang-format off nb::arg("input_token_ids"), nb::arg("max_tokens"), nb::kw_only(), nb::arg("streaming") = false, nb::arg("sampling_config") = tle::SamplingConfig(), nb::arg("output_config") = tle::OutputConfig(), nb::arg("end_id") = nb::none(), nb::arg("pad_id") = nb::none(), nb::arg("position_ids") = nb::none(), nb::arg("bad_words") = nb::none(), nb::arg("stop_words") = nb::none(), nb::arg("embedding_bias") = nb::none(), nb::arg("external_draft_tokens_config") = nb::none(), nb::arg("prompt_tuning_config") = nb::none(), nb::arg("multimodal_input") = nb::none(), nb::arg("multimodal_embedding") = nb::none(), nb::arg("mrope_config") = nb::none(), nb::arg("lora_config") = nb::none(), nb::arg("lookahead_config") = nb::none(), nb::arg("kv_cache_retention_config") = nb::none(), nb::arg("logits_post_processor_name") = nb::none(), nb::arg("logits_post_processor") = nb::none(), 
nb::arg("encoder_input_token_ids") = nb::none(), nb::arg("client_id") = nb::none(), nb::arg("return_all_generated_tokens") = false, nb::arg("priority") = tle::Request::kDefaultPriority, nb::arg("type") = tle::RequestType::REQUEST_TYPE_CONTEXT_AND_GENERATION, nb::arg("context_phase_params") = nb::none(), nb::arg("encoder_input_features") = nb::none(), nb::arg("encoder_output_length") = nb::none(), nb::arg("cross_attention_mask") = nb::none(), nb::arg("num_return_sequences") = 1, nb::arg("eagle_config") = nb::none(), nb::arg("skip_cross_attn_blocks") = nb::none(), nb::arg("guided_decoding_params") = nb::none(), nb::arg("language_adapter_uid") = nb::none(), nb::arg("allotted_time_ms") = nb::none(), nb::arg("cache_salt_id") = nb::none() ) // clang-format on .def_prop_ro("input_token_ids", &tle::Request::getInputTokenIds) .def_prop_ro("max_tokens", &tle::Request::getMaxTokens) .def_prop_rw("streaming", &tle::Request::getStreaming, &tle::Request::setStreaming) .def_prop_rw("sampling_config", &tle::Request::getSamplingConfig, &tle::Request::setSamplingConfig) .def_prop_rw("output_config", &tle::Request::getOutputConfig, &tle::Request::setOutputConfig) .def_prop_rw("end_id", &tle::Request::getEndId, &tle::Request::setEndId) .def_prop_rw("pad_id", &tle::Request::getPadId, &tle::Request::setPadId) .def_prop_rw("position_ids", &tle::Request::getPositionIds, &tle::Request::setPositionIds) .def_prop_rw("bad_words", &tle::Request::getBadWords, &tle::Request::setBadWords) .def_prop_rw("stop_words", &tle::Request::getStopWords, &tle::Request::setStopWords) .def_prop_rw("embedding_bias", &tle::Request::getEmbeddingBias, &tle::Request::setEmbeddingBias) .def_prop_rw("external_draft_tokens_config", &tle::Request::getExternalDraftTokensConfig, &tle::Request::setExternalDraftTokensConfig) .def_prop_rw("prompt_tuning_config", &tle::Request::getPromptTuningConfig, &tle::Request::setPromptTuningConfig) .def_prop_rw("multimodal_input", &tle::Request::getMultimodalInput, 
&tle::Request::setMultimodalInput) .def_prop_rw( "multimodal_embedding", &tle::Request::getMultimodalEmbedding, &tle::Request::setMultimodalEmbedding) .def_prop_rw("mrope_config", &tle::Request::getMropeConfig, &tle::Request::setMropeConfig) .def_prop_rw("lora_config", &tle::Request::getLoraConfig, &tle::Request::setLoraConfig) .def_prop_rw("lookahead_config", &tle::Request::getLookaheadConfig, &tle::Request::setLookaheadConfig) .def_prop_rw("kv_cache_retention_config", &tle::Request::getKvCacheRetentionConfig, &tle::Request::setKvCacheRetentionConfig) .def_prop_rw("logits_post_processor_name", &tle::Request::getLogitsPostProcessorName, &tle::Request::setLogitsPostProcessorName) .def_prop_rw( "logits_post_processor", &tle::Request::getLogitsPostProcessor, &tle::Request::setLogitsPostProcessor) .def_prop_rw( "encoder_input_token_ids", &tle::Request::getEncoderInputTokenIds, &tle::Request::setEncoderInputTokenIds) .def_prop_rw("client_id", &tle::Request::getClientId, &tle::Request::setClientId) .def_prop_rw("return_all_generated_tokens", &tle::Request::getReturnAllGeneratedTokens, &tle::Request::setReturnAllGeneratedTokens) .def_prop_rw("request_type", &tle::Request::getRequestType, &tle::Request::setRequestType) .def_prop_rw( "encoder_input_features", &tle::Request::getEncoderInputFeatures, &tle::Request::setEncoderInputFeatures) .def_prop_rw("cross_attention_mask", &tle::Request::getCrossAttentionMask, &tle::Request::setCrossAttentionMask) .def_prop_rw("eagle_config", &tle::Request::getEagleConfig, &tle::Request::setEagleConfig) .def_prop_rw( "skip_cross_attn_blocks", &tle::Request::getSkipCrossAttnBlocks, &tle::Request::setSkipCrossAttnBlocks) .def_prop_rw( "guided_decoding_params", &tle::Request::getGuidedDecodingParams, &tle::Request::setGuidedDecodingParams) .def_prop_rw("allotted_time_ms", &tle::Request::getAllottedTimeMs, &tle::Request::setAllottedTimeMs) .def_prop_rw("cache_salt_id", &tle::Request::getCacheSaltID, &tle::Request::setCacheSaltID) 
.def_prop_rw("context_phase_params", &tle::Request::getContextPhaseParams, &tle::Request::setContextPhaseParams) .def("__getstate__", requestGetstate) .def("__setstate__", requestSetstate); request.attr("BATCHED_POST_PROCESSOR_NAME") = tle::Request::kBatchedPostProcessorName; nb::class_(m, "SpeculativeDecodingFastLogitsInfo") .def(nb::init<>()) .def_rw("draft_request_id", &tle::SpeculativeDecodingFastLogitsInfo::draftRequestId) .def_rw("draft_participant_id", &tle::SpeculativeDecodingFastLogitsInfo::draftParticipantId) .def("to_tensor", &tle::SpeculativeDecodingFastLogitsInfo::toTensor); auto requestPerfMetrics = nb::class_(m, "RequestPerfMetrics"); auto timingMetricsGetstate = [](tle::RequestPerfMetrics::TimingMetrics const& self) { return nb::make_tuple(self.arrivalTime, self.firstScheduledTime, self.firstTokenTime, self.lastTokenTime, self.kvCacheTransferStart, self.kvCacheTransferEnd, self.kvCacheSize); }; auto timingMetricsSetstate = [](tle::RequestPerfMetrics::TimingMetrics& timingMetrics, nb::tuple const& state) { if (state.size() != 7) { throw std::runtime_error("Invalid TimingMetrics state!"); } new (&timingMetrics) tle::RequestPerfMetrics::TimingMetrics{nb::cast(state[0]), nb::cast(state[1]), nb::cast(state[2]), nb::cast(state[3]), nb::cast(state[4]), nb::cast(state[5]), nb::cast(state[6])}; }; nb::class_(m, "TimingMetrics") .def(nb::init<>()) .def_rw("arrival_time", &tle::RequestPerfMetrics::TimingMetrics::arrivalTime) .def_rw("first_scheduled_time", &tle::RequestPerfMetrics::TimingMetrics::firstScheduledTime) .def_rw("first_token_time", &tle::RequestPerfMetrics::TimingMetrics::firstTokenTime) .def_rw("last_token_time", &tle::RequestPerfMetrics::TimingMetrics::lastTokenTime) .def_rw("kv_cache_transfer_start", &tle::RequestPerfMetrics::TimingMetrics::kvCacheTransferStart) .def_rw("kv_cache_transfer_end", &tle::RequestPerfMetrics::TimingMetrics::kvCacheTransferEnd) .def_rw("kv_cache_size", &tle::RequestPerfMetrics::TimingMetrics::kvCacheSize) 
.def("__getstate__", timingMetricsGetstate) .def("__setstate__", timingMetricsSetstate); auto kvCacheMetricsGetstate = [](tle::RequestPerfMetrics::KvCacheMetrics const& self) { return nb::make_tuple(self.numTotalAllocatedBlocks, self.numNewAllocatedBlocks, self.numReusedBlocks, self.numMissedBlocks, self.kvCacheHitRate); }; auto kvCacheMetricsSetstate = [](tle::RequestPerfMetrics::KvCacheMetrics& kvCacheMetrics, nb::tuple const& state) { if (state.size() != 5) { throw std::runtime_error("Invalid KvCacheMetrics state!"); } new (&kvCacheMetrics) tle::RequestPerfMetrics::KvCacheMetrics{nb::cast(state[0]), nb::cast(state[1]), nb::cast(state[2]), nb::cast(state[3]), nb::cast(state[4])}; }; nb::class_(m, "KvCacheMetrics") .def(nb::init<>()) .def_rw("num_total_allocated_blocks", &tle::RequestPerfMetrics::KvCacheMetrics::numTotalAllocatedBlocks) .def_rw("num_new_allocated_blocks", &tle::RequestPerfMetrics::KvCacheMetrics::numNewAllocatedBlocks) .def_rw("num_reused_blocks", &tle::RequestPerfMetrics::KvCacheMetrics::numReusedBlocks) .def_rw("num_missed_blocks", &tle::RequestPerfMetrics::KvCacheMetrics::numMissedBlocks) .def_rw("kv_cache_hit_rate", &tle::RequestPerfMetrics::KvCacheMetrics::kvCacheHitRate) .def("__getstate__", kvCacheMetricsGetstate) .def("__setstate__", kvCacheMetricsSetstate); auto speculativeDecodingMetricsGetstate = [](tle::RequestPerfMetrics::SpeculativeDecodingMetrics const& self) { return nb::make_tuple(self.acceptanceRate, self.totalAcceptedDraftTokens, self.totalDraftTokens); }; auto speculativeDecodingMetricsSetstate = [](tle::RequestPerfMetrics::SpeculativeDecodingMetrics& speculativeDecodingMetrics, nb::tuple const& state) { if (state.size() != 3) { throw std::runtime_error("Invalid SpeculativeDecodingMetrics state!"); } new (&speculativeDecodingMetrics) tle::RequestPerfMetrics::SpeculativeDecodingMetrics{ nb::cast(state[0]), nb::cast(state[1]), nb::cast(state[2])}; }; nb::class_(m, "SpeculativeDecodingMetrics") .def(nb::init<>()) 
.def_rw("acceptance_rate", &tle::RequestPerfMetrics::SpeculativeDecodingMetrics::acceptanceRate) .def_rw("total_accepted_draft_tokens", &tle::RequestPerfMetrics::SpeculativeDecodingMetrics::totalAcceptedDraftTokens) .def_rw("total_draft_tokens", &tle::RequestPerfMetrics::SpeculativeDecodingMetrics::totalDraftTokens) .def("__getstate__", speculativeDecodingMetricsGetstate) .def("__setstate__", speculativeDecodingMetricsSetstate); auto requestPerfMetricsGetstate = [](tle::RequestPerfMetrics const& self) { return nb::make_tuple(self.timingMetrics, self.kvCacheMetrics, self.speculativeDecoding, self.firstIter, self.lastIter, self.iter); }; auto requestPerfMetricsSetstate = [](tle::RequestPerfMetrics& self, nb::tuple const& state) { if (state.size() != 6) { throw std::runtime_error("Invalid RequestPerfMetrics state!"); } new (&self) tle::RequestPerfMetrics{nb::cast(state[0]), nb::cast(state[1]), nb::cast(state[2]), nb::cast>(state[3]), nb::cast>(state[4]), nb::cast>(state[5])}; }; // There's a circular dependency between the declaration of the TimingMetrics and RequestPerfMetrics bindings. // Defer definition of the RequestPerfMetrics bindings until the TimingMetrics have been defined. 
requestPerfMetrics.def(nb::init<>()) .def_rw("timing_metrics", &tle::RequestPerfMetrics::timingMetrics) .def_rw("kv_cache_metrics", &tle::RequestPerfMetrics::kvCacheMetrics) .def_rw("speculative_decoding", &tle::RequestPerfMetrics::speculativeDecoding) .def_rw("first_iter", &tle::RequestPerfMetrics::firstIter) .def_rw("last_iter", &tle::RequestPerfMetrics::lastIter) .def_rw("iter", &tle::RequestPerfMetrics::iter) .def("__getstate__", requestPerfMetricsGetstate) .def("__setstate__", requestPerfMetricsSetstate); nb::class_(m, "AdditionalOutput") .def(nb::init(), nb::arg("name"), nb::arg("output")) .def_rw("name", &tle::AdditionalOutput::name) .def_rw("output", &tle::AdditionalOutput::output); auto resultSetstate = [](tle::Result& self, nb::tuple const& state) { if (state.size() != 14) { throw std::runtime_error("Invalid Request state!"); } tle::Result result; result.isFinal = nb::cast(state[0]); result.outputTokenIds = nb::cast>(state[1]); result.cumLogProbs = nb::cast>>(state[2]); result.logProbs = nb::cast>>>(state[3]); result.contextLogits = nb::cast>(state[4]); result.generationLogits = nb::cast>(state[5]); result.encoderOutput = nb::cast>(state[6]); result.finishReasons = nb::cast>(state[7]); result.sequenceIndex = nb::cast(state[8]); result.isSequenceFinal = nb::cast(state[9]); result.decodingIter = nb::cast(state[10]); result.avgDecodedTokensPerIter = nb::cast(state[11]); result.contextPhaseParams = nb::cast>(state[12]); result.requestPerfMetrics = nb::cast>(state[13]); new (&self) tle::Result(result); }; auto resultGetstate = [](tle::Result const& self) { return nb::make_tuple(self.isFinal, self.outputTokenIds, self.cumLogProbs, self.logProbs, self.contextLogits, self.generationLogits, self.encoderOutput, self.finishReasons, self.sequenceIndex, self.isSequenceFinal, self.decodingIter, self.avgDecodedTokensPerIter, self.contextPhaseParams, self.requestPerfMetrics); }; nb::class_(m, "Result") .def(nb::init<>()) .def_rw("is_final", &tle::Result::isFinal) 
.def_rw("output_token_ids", &tle::Result::outputTokenIds) .def_rw("cum_log_probs", &tle::Result::cumLogProbs, nb::arg("cum_log_probs").none()) .def_rw("log_probs", &tle::Result::logProbs, nb::arg("log_probs").none()) .def_rw("context_logits", &tle::Result::contextLogits, nb::arg("context_logits").none()) .def_rw("generation_logits", &tle::Result::generationLogits, nb::arg("generation_logits").none()) .def_rw("spec_dec_fast_logits_info", &tle::Result::specDecFastLogitsInfo, nb::arg("spec_dec_fast_logits_info").none()) .def_rw("encoder_output", &tle::Result::encoderOutput, nb::arg("encoder_output").none()) .def_rw("finish_reasons", &tle::Result::finishReasons) .def_rw("sequence_index", &tle::Result::sequenceIndex) .def_rw("is_sequence_final", &tle::Result::isSequenceFinal) .def_rw("decoding_iter", &tle::Result::decodingIter) .def_rw("avg_decoded_tokens_per_iter", &tle::Result::avgDecodedTokensPerIter) .def_rw("context_phase_params", &tle::Result::contextPhaseParams, nb::arg("context_phase_params").none()) .def_rw("request_perf_metrics", &tle::Result::requestPerfMetrics, nb::arg("request_perf_metrics").none()) .def_rw("additional_outputs", &tle::Result::additionalOutputs) .def("__getstate__", resultGetstate) .def("__setstate__", resultSetstate); m.def("deserialize_result", [](nb::bytes& x) { std::string str(x.c_str(), x.size()); std::istringstream is(str); return tle::serialize_utils::deserialize(is); }); auto responseGetstate = [](tle::Response const& self) { return nb::make_tuple(self.getRequestId(), self.getResult(), self.getClientId()); }; auto responseSetstate = [](tle::Response& response, nb::tuple const& state) { if (state.size() != 3) { throw std::runtime_error("Invalid Request state!"); } new (&response) tle::Response( nb::cast(state[0]), nb::cast(state[1]), nb::cast(state[2])); }; nb::class_(m, "Response") .def(nb::init>(), nb::arg("request_id"), nb::arg("error_msg"), nb::arg("client_id") = std::nullopt) .def(nb::init>(), nb::arg("request_id"), 
nb::arg("result"), nb::arg("client_id") = std::nullopt) .def_prop_ro("request_id", &tle::Response::getRequestId) .def_prop_ro("client_id", &tle::Response::getClientId) .def("has_error", &tle::Response::hasError) .def_prop_ro("error_msg", &tle::Response::getErrorMsg) .def_prop_ro("result", &tle::Response::getResult) .def("clear_context_logits", [](tle::Response& self) { if (!self.hasError()) { auto& result = const_cast(self.getResult()); result.contextLogits.reset(); } }) .def("clear_generation_logits", [](tle::Response& self) { if (!self.hasError()) { auto& result = const_cast(self.getResult()); result.generationLogits.reset(); } }) .def("__getstate__", responseGetstate) .def("__setstate__", responseSetstate); } } // namespace tensorrt_llm::nanobind::executor