mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
926 lines
56 KiB
C++
926 lines
56 KiB
C++
/*
|
|
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
* SPDX-License-Identifier: Apache-2.0
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
#include "request.h"
|
|
#include "tensorrt_llm/common/assert.h"
|
|
#include "tensorrt_llm/common/logger.h"
|
|
#include "tensorrt_llm/executor/executor.h"
|
|
#include "tensorrt_llm/executor/serializeUtils.h"
|
|
#include "tensorrt_llm/executor/tensor.h"
|
|
#include "tensorrt_llm/executor/types.h"
|
|
#include "tensorrt_llm/runtime/cudaStream.h"
|
|
|
|
#include <pybind11/cast.h>
|
|
#include <pybind11/chrono.h>
|
|
#include <pybind11/functional.h>
|
|
#include <pybind11/operators.h>
|
|
#include <pybind11/pybind11.h>
|
|
#include <pybind11/stl.h>
|
|
#include <sstream>
|
|
|
|
#include <optional>
|
|
#include <vector>
|
|
|
|
namespace py = pybind11;
|
|
namespace tle = tensorrt_llm::executor;
|
|
using Tensor = tle::Tensor;
|
|
using SizeType32 = tle::SizeType32;
|
|
using FloatType = tle::FloatType;
|
|
using VecTokens = tle::VecTokens;
|
|
using IdType = tle::IdType;
|
|
using VecTokenExtraIds = tle::VecTokenExtraIds;
|
|
|
|
namespace tensorrt_llm::pybind::executor
|
|
{
|
|
|
|
void initRequestBindings(pybind11::module_& m)
|
|
{
|
|
py::enum_<tle::RequestType>(m, "RequestType")
|
|
.value("REQUEST_TYPE_CONTEXT_AND_GENERATION", tle::RequestType::REQUEST_TYPE_CONTEXT_AND_GENERATION)
|
|
.value("REQUEST_TYPE_CONTEXT_ONLY", tle::RequestType::REQUEST_TYPE_CONTEXT_ONLY)
|
|
.value("REQUEST_TYPE_GENERATION_ONLY", tle::RequestType::REQUEST_TYPE_GENERATION_ONLY);
|
|
|
|
py::enum_<tle::FinishReason>(m, "FinishReason")
|
|
.value("NOT_FINISHED", tle::FinishReason::kNOT_FINISHED)
|
|
.value("END_ID", tle::FinishReason::kEND_ID)
|
|
.value("STOP_WORDS", tle::FinishReason::kSTOP_WORDS)
|
|
.value("LENGTH", tle::FinishReason::kLENGTH)
|
|
.value("TIMED_OUT", tle::FinishReason::kTIMED_OUT)
|
|
.value("CANCELLED", tle::FinishReason::kCANCELLED);
|
|
|
|
py::enum_<tle::KvCacheTransferMode>(m, "KvCacheTransferMode")
|
|
.value("DRAM", tle::KvCacheTransferMode::DRAM)
|
|
.value("GDS", tle::KvCacheTransferMode::GDS)
|
|
.value("POSIX_DEBUG_FALLBACK", tle::KvCacheTransferMode::POSIX_DEBUG_FALLBACK);
|
|
|
|
auto samplingConfigGetstate = [](tle::SamplingConfig const& self)
|
|
{
|
|
return py::make_tuple(self.getBeamWidth(), self.getTopK(), self.getTopP(), self.getTopPMin(),
|
|
self.getTopPResetIds(), self.getTopPDecay(), self.getSeed(), self.getTemperature(), self.getMinTokens(),
|
|
self.getBeamSearchDiversityRate(), self.getRepetitionPenalty(), self.getPresencePenalty(),
|
|
self.getFrequencyPenalty(), self.getPromptIgnoreLength(), self.getLengthPenalty(), self.getEarlyStopping(),
|
|
self.getNoRepeatNgramSize(), self.getNumReturnSequences(), self.getMinP(), self.getBeamWidthArray());
|
|
};
|
|
auto samplingConfigSetstate = [](py::tuple const& state)
|
|
{
|
|
if (state.size() != 20)
|
|
{
|
|
throw std::runtime_error("Invalid SamplingConfig state!");
|
|
}
|
|
return tle::SamplingConfig(state[0].cast<SizeType32>(), // BeamWidth
|
|
state[1].cast<std::optional<SizeType32>>(), // TopK
|
|
state[2].cast<std::optional<FloatType>>(), // TopP
|
|
state[3].cast<std::optional<FloatType>>(), // TopPMin
|
|
state[4].cast<std::optional<tle::TokenIdType>>(), // TopPResetIds
|
|
state[5].cast<std::optional<FloatType>>(), // TopPDecay
|
|
state[6].cast<std::optional<tle::RandomSeedType>>(), // Seed
|
|
state[7].cast<std::optional<FloatType>>(), // Temperature
|
|
state[8].cast<std::optional<SizeType32>>(), // MinTokens
|
|
state[9].cast<std::optional<FloatType>>(), // BeamSearchDiversityRate
|
|
state[10].cast<std::optional<FloatType>>(), // RepetitionPenalty
|
|
state[11].cast<std::optional<FloatType>>(), // PresencePenalty
|
|
state[12].cast<std::optional<FloatType>>(), // FrequencyPenalty
|
|
state[13].cast<std::optional<SizeType32>>(), // PromptIgnoreLength
|
|
state[14].cast<std::optional<FloatType>>(), // LengthPenalty
|
|
state[15].cast<std::optional<SizeType32>>(), // EarlyStopping
|
|
state[16].cast<std::optional<SizeType32>>(), // NoRepeatNgramSize
|
|
state[17].cast<std::optional<SizeType32>>(), // NumReturnSequences
|
|
state[18].cast<std::optional<FloatType>>(), // MinP
|
|
state[19].cast<std::optional<std::vector<SizeType32>>>() // BeamWidthArray
|
|
);
|
|
};
|
|
py::class_<tle::SamplingConfig>(m, "SamplingConfig")
|
|
.def(py::init<tle::SizeType32,
|
|
std::optional<tle::SizeType32> const&, // beamWidth
|
|
std::optional<tle::FloatType> const&, // topP
|
|
std::optional<tle::FloatType> const&, // topPMin
|
|
std::optional<tle::TokenIdType> const&, // topPResetIds
|
|
std::optional<tle::FloatType> const&, // topPDecay
|
|
std::optional<tle::RandomSeedType> const&, // seed
|
|
std::optional<tle::FloatType> const&, // temperature
|
|
std::optional<tle::SizeType32> const&, // minTokens
|
|
std::optional<tle::FloatType> const&, // beamSearchDiversityRate
|
|
std::optional<tle::FloatType> const&, // repetitionPenalty
|
|
std::optional<tle::FloatType> const&, // presencePenalty
|
|
std::optional<tle::FloatType> const&, // frequencyPenalty
|
|
std::optional<tle::SizeType32> const&, // promptIgnoreLength
|
|
std::optional<tle::FloatType> const&, // lengthPenalty
|
|
std::optional<tle::SizeType32> const&, // earlyStopping
|
|
std::optional<tle::SizeType32> const&, // noRepeatNgramSize
|
|
std::optional<tle::SizeType32> const&, // numReturnSequences
|
|
std::optional<tle::FloatType> const&, // minP
|
|
std::optional<std::vector<tle::SizeType32>> const& // beamWidthArray
|
|
>(),
|
|
// clang-format off
|
|
py::arg("beam_width") = 1,
|
|
py::kw_only(),
|
|
py::arg("top_k") = py::none(),
|
|
py::arg("top_p") = py::none(),
|
|
py::arg("top_p_min") = py::none(),
|
|
py::arg("top_p_reset_ids") = py::none(),
|
|
py::arg("top_p_decay") = py::none(),
|
|
py::arg("seed") = py::none(),
|
|
py::arg("temperature") = py::none(),
|
|
py::arg("min_tokens") = py::none(),
|
|
py::arg("beam_search_diversity_rate") = py::none(),
|
|
py::arg("repetition_penalty") = py::none(),
|
|
py::arg("presence_penalty") = py::none(),
|
|
py::arg("frequency_penalty") = py::none(),
|
|
py::arg("prompt_ignore_length") = py::none(),
|
|
py::arg("length_penalty") = py::none(),
|
|
py::arg("early_stopping") = py::none(),
|
|
py::arg("no_repeat_ngram_size") = py::none(),
|
|
py::arg("num_return_sequences") = py::none(),
|
|
py::arg("min_p") = py::none(),
|
|
py::arg("beam_width_array") = py::none()) // clang-format on
|
|
.def_property("beam_width", &tle::SamplingConfig::getBeamWidth, &tle::SamplingConfig::setBeamWidth)
|
|
.def_property("top_k", &tle::SamplingConfig::getTopK, &tle::SamplingConfig::setTopK)
|
|
.def_property("top_p", &tle::SamplingConfig::getTopP, &tle::SamplingConfig::setTopP)
|
|
.def_property("top_p_min", &tle::SamplingConfig::getTopPMin, &tle::SamplingConfig::setTopPMin)
|
|
.def_property("top_p_reset_ids", &tle::SamplingConfig::getTopPResetIds, &tle::SamplingConfig::setTopPResetIds)
|
|
.def_property("top_p_decay", &tle::SamplingConfig::getTopPDecay, &tle::SamplingConfig::setTopPDecay)
|
|
.def_property("seed", &tle::SamplingConfig::getSeed, &tle::SamplingConfig::setSeed)
|
|
.def_property("temperature", &tle::SamplingConfig::getTemperature, &tle::SamplingConfig::setTemperature)
|
|
.def_property("min_tokens", &tle::SamplingConfig::getMinTokens, &tle::SamplingConfig::setMinTokens)
|
|
.def_property("beam_search_diversity_rate", &tle::SamplingConfig::getBeamSearchDiversityRate,
|
|
&tle::SamplingConfig::setBeamSearchDiversityRate)
|
|
.def_property("repetition_penalty", &tle::SamplingConfig::getRepetitionPenalty,
|
|
&tle::SamplingConfig::setRepetitionPenalty)
|
|
.def_property("presence_penalty", &tle::SamplingConfig::getPresencePenalty,
|
|
[](tle::SamplingConfig& self, std::optional<FloatType> v) { self.setPresencePenalty(v); })
|
|
.def_property(
|
|
"frequency_penalty", &tle::SamplingConfig::getFrequencyPenalty, &tle::SamplingConfig::setFrequencyPenalty)
|
|
.def_property("prompt_ignore_length", &tle::SamplingConfig::getPromptIgnoreLength,
|
|
&tle::SamplingConfig::setPromptIgnoreLength)
|
|
.def_property("length_penalty", &tle::SamplingConfig::getLengthPenalty, &tle::SamplingConfig::setLengthPenalty)
|
|
.def_property("early_stopping", &tle::SamplingConfig::getEarlyStopping, &tle::SamplingConfig::setEarlyStopping)
|
|
.def_property("no_repeat_ngram_size", &tle::SamplingConfig::getNoRepeatNgramSize,
|
|
&tle::SamplingConfig::setNoRepeatNgramSize)
|
|
.def_property("num_return_sequences", &tle::SamplingConfig::getNumReturnSequences,
|
|
&tle::SamplingConfig::setNumReturnSequences)
|
|
.def_property("min_p", &tle::SamplingConfig::getMinP, &tle::SamplingConfig::setMinP)
|
|
.def_property(
|
|
"beam_width_array", &tle::SamplingConfig::getBeamWidthArray, &tle::SamplingConfig::setBeamWidthArray)
|
|
.def(py::pickle(samplingConfigGetstate, samplingConfigSetstate));
|
|
|
|
auto additionalModelOutputGetstate
|
|
= [](tle::AdditionalModelOutput const& self) { return py::make_tuple(self.name, self.gatherContext); };
|
|
auto additionalModelOutputSetstate = [](py::tuple const& state)
|
|
{
|
|
if (state.size() != 2)
|
|
{
|
|
throw std::runtime_error("Invalid AdditionalModelOutput state!");
|
|
}
|
|
return tle::AdditionalModelOutput(state[0].cast<std::string>(), state[1].cast<bool>());
|
|
};
|
|
py::class_<tle::AdditionalModelOutput>(m, "AdditionalModelOutput")
|
|
.def(py::init<std::string, bool>(), py::arg("name"), py::arg("gather_context") = false)
|
|
.def_readwrite("name", &tle::AdditionalModelOutput::name)
|
|
.def_readwrite("gather_context", &tle::AdditionalModelOutput::gatherContext)
|
|
.def(py::pickle(additionalModelOutputGetstate, additionalModelOutputSetstate));
|
|
|
|
auto outputConfigGetstate = [](tle::OutputConfig const& self)
|
|
{
|
|
return py::make_tuple(self.returnLogProbs, self.returnContextLogits, self.returnGenerationLogits,
|
|
self.excludeInputFromOutput, self.returnEncoderOutput, self.returnPerfMetrics, self.additionalModelOutputs);
|
|
};
|
|
auto outputConfigSetstate = [](py::tuple const& state)
|
|
{
|
|
if (state.size() != 7)
|
|
{
|
|
throw std::runtime_error("Invalid OutputConfig state!");
|
|
}
|
|
return tle::OutputConfig(state[0].cast<bool>(), state[1].cast<bool>(), state[2].cast<bool>(),
|
|
state[3].cast<bool>(), state[4].cast<bool>(), state[5].cast<bool>(),
|
|
state[6].cast<std::optional<std::vector<tle::AdditionalModelOutput>>>());
|
|
};
|
|
py::class_<tle::OutputConfig>(m, "OutputConfig")
|
|
.def(py::init<bool, bool, bool, bool, bool, bool, std::optional<std::vector<tle::AdditionalModelOutput>>>(),
|
|
py::arg("return_log_probs") = false, py::arg("return_context_logits") = false,
|
|
py::arg("return_generation_logits") = false, py::arg("exclude_input_from_output") = false,
|
|
py::arg("return_encoder_output") = false, py::arg("return_perf_metrics") = false,
|
|
py::arg("additional_model_outputs") = py::none())
|
|
.def_readwrite("return_log_probs", &tle::OutputConfig::returnLogProbs)
|
|
.def_readwrite("return_context_logits", &tle::OutputConfig::returnContextLogits)
|
|
.def_readwrite("return_generation_logits", &tle::OutputConfig::returnGenerationLogits)
|
|
.def_readwrite("exclude_input_from_output", &tle::OutputConfig::excludeInputFromOutput)
|
|
.def_readwrite("return_encoder_output", &tle::OutputConfig::returnEncoderOutput)
|
|
.def_readwrite("return_perf_metrics", &tle::OutputConfig::returnPerfMetrics)
|
|
.def_readwrite("additional_model_outputs", &tle::OutputConfig::additionalModelOutputs)
|
|
.def(py::pickle(outputConfigGetstate, outputConfigSetstate));
|
|
|
|
auto externalDraftTokensConfigGetstate = [](tle::ExternalDraftTokensConfig const& self)
|
|
{ return py::make_tuple(self.getTokens(), self.getLogits(), self.getAcceptanceThreshold()); };
|
|
auto externalDraftTokensConfigSetstate = [](py::tuple const& state)
|
|
{
|
|
if (state.size() != 3)
|
|
{
|
|
throw std::runtime_error("Invalid ExternalDraftTokensConfig state!");
|
|
}
|
|
return tle::ExternalDraftTokensConfig(state[0].cast<VecTokens>(), state[1].cast<std::optional<Tensor>>(),
|
|
state[2].cast<std::optional<FloatType>>());
|
|
};
|
|
py::class_<tle::ExternalDraftTokensConfig>(m, "ExternalDraftTokensConfig")
|
|
.def(py::init<VecTokens, std::optional<Tensor>, std::optional<FloatType> const&, std::optional<bool>>(),
|
|
py::arg("tokens"), py::arg("logits") = py::none(), py::arg("acceptance_threshold") = py::none(),
|
|
py::arg("fast_logits") = py::none())
|
|
.def_property_readonly("tokens", &tle::ExternalDraftTokensConfig::getTokens)
|
|
.def_property_readonly("logits", &tle::ExternalDraftTokensConfig::getLogits)
|
|
.def_property_readonly("acceptance_threshold", &tle::ExternalDraftTokensConfig::getAcceptanceThreshold)
|
|
.def(py::pickle(externalDraftTokensConfigGetstate, externalDraftTokensConfigSetstate))
|
|
.def_property_readonly("fast_logits", &tle::ExternalDraftTokensConfig::getFastLogits);
|
|
|
|
auto promptTuningConfigGetstate = [](tle::PromptTuningConfig const& self)
|
|
{ return py::make_tuple(self.getEmbeddingTable(), self.getInputTokenExtraIds()); };
|
|
auto promptTuningConfigSetstate = [](py::tuple const& state)
|
|
{
|
|
if (state.size() != 2)
|
|
{
|
|
throw std::runtime_error("Invalid PromptTuningConfig state!");
|
|
}
|
|
return tle::PromptTuningConfig(state[0].cast<Tensor>(), state[1].cast<std::optional<VecTokenExtraIds>>());
|
|
};
|
|
py::class_<tle::PromptTuningConfig>(m, "PromptTuningConfig")
|
|
.def(py::init<Tensor, std::optional<VecTokenExtraIds>>(), py::arg("embedding_table"),
|
|
py::arg("input_token_extra_ids") = py::none())
|
|
.def_property_readonly("embedding_table", &tle::PromptTuningConfig::getEmbeddingTable)
|
|
.def_property_readonly("input_token_extra_ids", &tle::PromptTuningConfig::getInputTokenExtraIds)
|
|
.def(py::pickle(promptTuningConfigGetstate, promptTuningConfigSetstate));
|
|
|
|
auto loraConfigGetstate = [](tle::LoraConfig const& self)
|
|
{ return py::make_tuple(self.getTaskId(), self.getWeights(), self.getConfig()); };
|
|
auto loraConfigSetstate = [](py::tuple const& state)
|
|
{
|
|
if (state.size() != 3)
|
|
{
|
|
throw std::runtime_error("Invalid LoraConfig state!");
|
|
}
|
|
return tle::LoraConfig(
|
|
state[0].cast<IdType>(), state[1].cast<std::optional<Tensor>>(), state[2].cast<std::optional<Tensor>>());
|
|
};
|
|
py::class_<tle::LoraConfig>(m, "LoraConfig")
|
|
.def(py::init<uint64_t, std::optional<Tensor>, std::optional<Tensor>>(), py::arg("task_id"),
|
|
py::arg("weights") = py::none(), py::arg("config") = py::none())
|
|
.def_property_readonly("task_id", &tle::LoraConfig::getTaskId)
|
|
.def_property_readonly("weights", &tle::LoraConfig::getWeights)
|
|
.def_property_readonly("config", &tle::LoraConfig::getConfig)
|
|
.def(py::pickle(loraConfigGetstate, loraConfigSetstate));
|
|
|
|
auto multimodalInputGetstate = [](tle::MultimodalInput const& self)
|
|
{ return py::make_tuple(self.getMultimodalHashes(), self.getMultimodalPositions(), self.getMultimodalLengths()); };
|
|
auto multimodalInputSetstate = [](py::tuple const& state)
|
|
{
|
|
if (state.size() != 3)
|
|
{
|
|
throw std::runtime_error("Invalid MultimodalInput state!");
|
|
}
|
|
return tle::MultimodalInput(state[0].cast<std::vector<std::vector<SizeType32>>>(),
|
|
state[1].cast<std::vector<SizeType32>>(), state[2].cast<std::vector<SizeType32>>());
|
|
};
|
|
py::class_<tle::MultimodalInput>(m, "MultimodalInput")
|
|
.def(py::init<std::vector<std::vector<SizeType32>>, std::vector<SizeType32>, std::vector<SizeType32>>(),
|
|
py::arg("multimodal_hashes"), py::arg("multimodal_positions"), py::arg("multimodal_lengths"))
|
|
.def_property_readonly("multimodal_hashes", &tle::MultimodalInput::getMultimodalHashes)
|
|
.def_property_readonly("multimodal_positions", &tle::MultimodalInput::getMultimodalPositions)
|
|
.def_property_readonly("multimodal_lengths", &tle::MultimodalInput::getMultimodalLengths)
|
|
.def(py::pickle(multimodalInputGetstate, multimodalInputSetstate));
|
|
|
|
auto MropeConfigGetstate = [](tle::MropeConfig const& self)
|
|
{ return py::make_tuple(self.getMRopeRotaryCosSin(), self.getMRopePositionDeltas()); };
|
|
auto MropeConfigSetstate = [](py::tuple const& state)
|
|
{
|
|
if (state.size() != 2)
|
|
{
|
|
throw std::runtime_error("Invalid MropeConfig state!");
|
|
}
|
|
return tle::MropeConfig(state[0].cast<tle::Tensor>(), state[1].cast<SizeType32>());
|
|
};
|
|
py::class_<tle::MropeConfig>(m, "MropeConfig")
|
|
.def(py::init<Tensor, SizeType32>(), py::arg("mrope_rotary_cos_sin"), py::arg("mrope_position_deltas"))
|
|
.def_property_readonly("mrope_rotary_cos_sin", &tle::MropeConfig::getMRopeRotaryCosSin)
|
|
.def_property_readonly("mrope_position_deltas", &tle::MropeConfig::getMRopePositionDeltas)
|
|
.def(py::pickle(MropeConfigGetstate, MropeConfigSetstate));
|
|
|
|
auto lookaheadDecodingConfigGetstate = [](tle::LookaheadDecodingConfig const& self)
|
|
{ return py::make_tuple(self.getWindowSize(), self.getNgramSize(), self.getVerificationSetSize()); };
|
|
auto lookaheadDecodingConfigSetstate = [](py::tuple const& state)
|
|
{
|
|
if (state.size() != 3)
|
|
{
|
|
throw std::runtime_error("Invalid LookaheadDecodingConfig state!");
|
|
}
|
|
return tle::LookaheadDecodingConfig(
|
|
state[0].cast<SizeType32>(), state[1].cast<SizeType32>(), state[2].cast<SizeType32>());
|
|
};
|
|
py::class_<tle::LookaheadDecodingConfig>(m, "LookaheadDecodingConfig")
|
|
.def(py::init<SizeType32, SizeType32, SizeType32>(), py::arg("max_window_size"), py::arg("max_ngram_size"),
|
|
py::arg("max_verification_set_size"))
|
|
.def_property_readonly("max_window_size", &tle::LookaheadDecodingConfig::getWindowSize)
|
|
.def_property_readonly("max_ngram_size", &tle::LookaheadDecodingConfig::getNgramSize)
|
|
.def_property_readonly("max_verification_set_size", &tle::LookaheadDecodingConfig::getVerificationSetSize)
|
|
.def("calculate_speculative_resource", &tle::LookaheadDecodingConfig::calculateSpeculativeResource)
|
|
.def_static(
|
|
"calculate_speculative_resource_tuple", &tle::LookaheadDecodingConfig::calculateSpeculativeResourceTuple)
|
|
.def(py::pickle(lookaheadDecodingConfigGetstate, lookaheadDecodingConfigSetstate))
|
|
.def_static("get_default_lookahead_decoding_window",
|
|
[]() { return tle::LookaheadDecodingConfig::kDefaultLookaheadDecodingWindow; })
|
|
.def_static("get_default_lookahead_decoding_ngram",
|
|
[]() { return tle::LookaheadDecodingConfig::kDefaultLookaheadDecodingNgram; })
|
|
.def_static("get_default_lookahead_decoding_verification_set",
|
|
[]() { return tle::LookaheadDecodingConfig::kDefaultLookaheadDecodingVerificationSet; });
|
|
|
|
auto TokenRangeRetentionConfigGetstate = [](tle::KvCacheRetentionConfig::TokenRangeRetentionConfig const& self)
|
|
{ return py::make_tuple(self.tokenStart, self.tokenEnd, self.priority, self.durationMs); };
|
|
auto TokenRangeRetentionConfigSetstate = [](py::tuple const& state)
|
|
{
|
|
if (state.size() != 4)
|
|
{
|
|
throw std::runtime_error("Invalid state!");
|
|
}
|
|
return tle::KvCacheRetentionConfig::TokenRangeRetentionConfig(state[0].cast<SizeType32>(),
|
|
state[1].cast<std::optional<SizeType32>>(), state[2].cast<tle::RetentionPriority>(),
|
|
state[3].cast<std::optional<std::chrono::milliseconds>>());
|
|
};
|
|
auto kvCacheRetentionConfigGetstate = [](tle::KvCacheRetentionConfig const& self)
|
|
{
|
|
return py::make_tuple(self.getTokenRangeRetentionConfigs(), self.getDecodeRetentionPriority(),
|
|
self.getDecodeDurationMs(), self.getTransferMode(), self.getDirectory());
|
|
};
|
|
auto kvCacheRetentionConfigSetstate = [](py::tuple const& state)
|
|
{
|
|
if (state.size() != 5)
|
|
{
|
|
throw std::runtime_error("Invalid state!");
|
|
}
|
|
return tle::KvCacheRetentionConfig(
|
|
state[0].cast<std::vector<tle::KvCacheRetentionConfig::TokenRangeRetentionConfig>>(),
|
|
state[1].cast<tle::RetentionPriority>(), state[2].cast<std::optional<std::chrono::milliseconds>>(),
|
|
state[3].cast<tle::KvCacheTransferMode>(), state[4].cast<std::string>());
|
|
};
|
|
|
|
auto kvCacheRetentionConfig = py::class_<tle::KvCacheRetentionConfig>(m, "KvCacheRetentionConfig");
|
|
|
|
py::class_<tle::KvCacheRetentionConfig::TokenRangeRetentionConfig>(
|
|
kvCacheRetentionConfig, "TokenRangeRetentionConfig")
|
|
.def(py::init<SizeType32, std::optional<SizeType32>, tle::RetentionPriority,
|
|
std::optional<std::chrono::milliseconds>>(),
|
|
py::arg("token_start"), py::arg("token_end"), py::arg("priority"), py::arg("duration_ms") = py::none())
|
|
.def_readwrite("token_start", &tle::KvCacheRetentionConfig::TokenRangeRetentionConfig::tokenStart)
|
|
.def_readwrite("token_end", &tle::KvCacheRetentionConfig::TokenRangeRetentionConfig::tokenEnd)
|
|
.def_readwrite("priority", &tle::KvCacheRetentionConfig::TokenRangeRetentionConfig::priority)
|
|
.def_readwrite("duration_ms", &tle::KvCacheRetentionConfig::TokenRangeRetentionConfig::durationMs)
|
|
.def(py::pickle(TokenRangeRetentionConfigGetstate, TokenRangeRetentionConfigSetstate))
|
|
.def("__eq__", &tle::KvCacheRetentionConfig::TokenRangeRetentionConfig::operator==);
|
|
|
|
// There's a circular dependency between the declaration of the TokenRangeRetentionPriority and
|
|
// KvCacheRetentionConfig bindings. Defer definition of the KvCacheRetentionConfig bindings until the
|
|
// TokenRangeRetentionPriority bindings have been defined.
|
|
kvCacheRetentionConfig
|
|
.def(py::init<std::vector<tle::KvCacheRetentionConfig::TokenRangeRetentionConfig>, tle::RetentionPriority,
|
|
std::optional<std::chrono::milliseconds>, tle::KvCacheTransferMode, std::string>(),
|
|
py::arg("token_range_retention_configs"),
|
|
py::arg("decode_retention_priority") = tle::KvCacheRetentionConfig::kDefaultRetentionPriority,
|
|
py::arg("decode_duration_ms") = py::none(),
|
|
py::arg_v("transfer_mode", tle::KvCacheTransferMode::DRAM, "DRAM"), py::arg("directory") = py::none())
|
|
.def_property_readonly(
|
|
"token_range_retention_configs", &tle::KvCacheRetentionConfig::getTokenRangeRetentionConfigs)
|
|
.def_property_readonly("decode_retention_priority", &tle::KvCacheRetentionConfig::getDecodeRetentionPriority)
|
|
.def_property_readonly("decode_duration_ms", &tle::KvCacheRetentionConfig::getDecodeDurationMs)
|
|
.def_property_readonly("transfer_mode", &tle::KvCacheRetentionConfig::getTransferMode)
|
|
.def_property_readonly("directory", &tle::KvCacheRetentionConfig::getDirectory)
|
|
.def(py::pickle(kvCacheRetentionConfigGetstate, kvCacheRetentionConfigSetstate))
|
|
.def("__eq__", &tle::KvCacheRetentionConfig::operator==);
|
|
|
|
auto ContextPhaseParamsGetState = [](tle::ContextPhaseParams const& self)
|
|
{
|
|
if (self.getState() != nullptr)
|
|
{
|
|
auto serializedState = self.getSerializedState();
|
|
return py::make_tuple(self.getFirstGenTokens(), self.getReqId(),
|
|
py::bytes(serializedState.data(), serializedState.size()), self.getDraftTokens(), self.getDisaggId(),
|
|
self.getCtxDpRank(), self.getDisaggInfoEndpoint());
|
|
}
|
|
return py::make_tuple(self.getFirstGenTokens(), self.getReqId(), py::none(), self.getDraftTokens(),
|
|
self.getDisaggId(), self.getCtxDpRank(), self.getDisaggInfoEndpoint());
|
|
};
|
|
|
|
auto ContextPhaseParamsSetState = [](py::tuple const& state)
|
|
{
|
|
if (state.size() != 7)
|
|
{
|
|
throw std::runtime_error("Invalid ContextPhaseParams state!");
|
|
}
|
|
if (!state[2].is_none())
|
|
{
|
|
auto opaque_state = state[2].cast<py::bytes>();
|
|
auto opaque_state_str_view = std::string_view(opaque_state.cast<std::string_view>());
|
|
return std::make_unique<tle::ContextPhaseParams>(state[0].cast<VecTokens>(),
|
|
state[1].cast<tle::ContextPhaseParams::RequestIdType>(),
|
|
std::vector<char>(opaque_state_str_view.begin(), opaque_state_str_view.end()),
|
|
state[3].cast<std::optional<VecTokens>>(), state[4].cast<std::optional<std::int64_t>>(),
|
|
state[5].cast<std::optional<SizeType32>>(), state[6].cast<std::optional<std::string>>());
|
|
}
|
|
return std::make_unique<tle::ContextPhaseParams>(state[0].cast<VecTokens>(),
|
|
state[1].cast<tle::ContextPhaseParams::RequestIdType>(), state[3].cast<std::optional<VecTokens>>(),
|
|
state[4].cast<std::optional<std::int64_t>>(), state[5].cast<std::optional<SizeType32>>(),
|
|
state[6].cast<std::optional<std::string>>());
|
|
};
|
|
|
|
py::class_<tle::ContextPhaseParams>(m, "ContextPhaseParams")
|
|
.def(py::init(
|
|
[](VecTokens const& first_gen_tokens, tle::ContextPhaseParams::RequestIdType req_id,
|
|
std::optional<py::bytes> const& opaque_state, std::optional<VecTokens> const& draft_tokens,
|
|
std::optional<std::int64_t> const& disagg_id, std::optional<SizeType32> const& ctx_dp_rank,
|
|
std::optional<std::string> const& disagg_info_endpoint)
|
|
{
|
|
if (opaque_state)
|
|
{
|
|
auto opaque_state_str_view = std::string_view(opaque_state.value().cast<std::string_view>());
|
|
return std::make_unique<tle::ContextPhaseParams>(first_gen_tokens, req_id,
|
|
std::vector<char>(opaque_state_str_view.begin(), opaque_state_str_view.end()),
|
|
draft_tokens, disagg_id, ctx_dp_rank, disagg_info_endpoint);
|
|
}
|
|
return std::make_unique<tle::ContextPhaseParams>(
|
|
first_gen_tokens, req_id, draft_tokens, disagg_id, ctx_dp_rank, disagg_info_endpoint);
|
|
}),
|
|
py::arg("first_gen_tokens"), py::arg("req_id"), py::arg("opaque_state") = py::none(),
|
|
py::arg("draft_tokens") = py::none(), py::arg("disagg_id") = py::none(),
|
|
py::arg("ctx_dp_rank") = py::none(), py::arg("disagg_info_endpoint") = py::none())
|
|
.def_property("first_gen_tokens", &tle::ContextPhaseParams::getFirstGenTokens,
|
|
&tle::ContextPhaseParams::setFirstGenTokens)
|
|
.def_property(
|
|
"draft_tokens", &tle::ContextPhaseParams::getDraftTokens, &tle::ContextPhaseParams::setDraftTokens)
|
|
.def_property("req_id", &tle::ContextPhaseParams::getReqId, &tle::ContextPhaseParams::setReqId)
|
|
.def_property("disagg_id", &tle::ContextPhaseParams::getDisaggId, &tle::ContextPhaseParams::setDisaggId)
|
|
.def_property("ctx_dp_rank", &tle::ContextPhaseParams::getCtxDpRank, &tle::ContextPhaseParams::setCtxDpRank)
|
|
.def_property("disagg_info_endpoint", &tle::ContextPhaseParams::getDisaggInfoEndpoint,
|
|
&tle::ContextPhaseParams::setDisaggInfoEndpoint)
|
|
.def_property_readonly("opaque_state",
|
|
[](tle::ContextPhaseParams const& self)
|
|
{
|
|
std::optional<py::bytes> opaque_state{std::nullopt};
|
|
if (self.getState() != nullptr)
|
|
{
|
|
auto serializedState = self.getSerializedState();
|
|
opaque_state = py::bytes(serializedState.data(), serializedState.size());
|
|
}
|
|
return opaque_state;
|
|
})
|
|
.def(py::pickle(ContextPhaseParamsGetState, ContextPhaseParamsSetState));
|
|
|
|
auto EagleDecodingConfigGetstate = [](tle::EagleConfig const& self)
|
|
{
|
|
return py::make_tuple(self.getEagleChoices(), self.isGreedySampling(), self.getPosteriorThreshold(),
|
|
self.useDynamicTree(), self.getDynamicTreeMaxTopK());
|
|
};
|
|
auto EagleDecodingConfigSetstate = [](py::tuple const& state)
|
|
{
|
|
if (state.size() != 5)
|
|
{
|
|
throw std::runtime_error("Invalid EagleConfig state!");
|
|
}
|
|
return tle::EagleConfig(state[0].cast<std::optional<tle::EagleChoices>>(), state[1].cast<bool>(),
|
|
state[2].cast<std::optional<float>>(), state[3].cast<bool>(), state[4].cast<std::optional<SizeType32>>());
|
|
};
|
|
py::class_<tle::EagleConfig>(m, "EagleConfig")
|
|
.def(py::init<std::optional<tle::EagleChoices>, bool, std::optional<float>, bool, std::optional<SizeType32>>(),
|
|
py::arg("eagle_choices") = py::none(), py::arg("greedy_sampling") = true,
|
|
py::arg("posterior_threshold") = py::none(), py::arg("use_dynamic_tree") = false,
|
|
py::arg("dynamic_tree_max_topK") = py::none())
|
|
.def_property_readonly("eagle_choices", &tle::EagleConfig::getEagleChoices)
|
|
.def_property_readonly("greedy_sampling", &tle::EagleConfig::isGreedySampling)
|
|
.def_property_readonly("posterior_threshold", &tle::EagleConfig::getPosteriorThreshold)
|
|
.def_property_readonly("use_dynamic_tree", &tle::EagleConfig::useDynamicTree)
|
|
.def_property_readonly("dynamic_tree_max_topK", &tle::EagleConfig::getDynamicTreeMaxTopK)
|
|
.def(py::pickle(EagleDecodingConfigGetstate, EagleDecodingConfigSetstate));
|
|
|
|
// Guided decoding params
|
|
auto pyGuidedDecodingParams = py::class_<tle::GuidedDecodingParams>(m, "GuidedDecodingParams");
|
|
|
|
py::enum_<tle::GuidedDecodingParams::GuideType>(pyGuidedDecodingParams, "GuideType")
|
|
.value("JSON", tle::GuidedDecodingParams::GuideType::kJSON)
|
|
.value("JSON_SCHEMA", tle::GuidedDecodingParams::GuideType::kJSON_SCHEMA)
|
|
.value("REGEX", tle::GuidedDecodingParams::GuideType::kREGEX)
|
|
.value("EBNF_GRAMMAR", tle::GuidedDecodingParams::GuideType::kEBNF_GRAMMAR)
|
|
.value("STRUCTURAL_TAG", tle::GuidedDecodingParams::GuideType::kSTRUCTURAL_TAG);
|
|
|
|
auto guidedDecodingParamsGetstate
|
|
= [](tle::GuidedDecodingParams const& self) { return py::make_tuple(self.getGuideType(), self.getGuide()); };
|
|
|
|
auto guidedDecodingParamsSetstate = [](py::tuple state)
|
|
{
|
|
if (state.size() != 2)
|
|
{
|
|
throw std::runtime_error("Invalid GuidedDecodingParams state!");
|
|
}
|
|
return tle::GuidedDecodingParams(
|
|
state[0].cast<tle::GuidedDecodingParams::GuideType>(), state[1].cast<std::optional<std::string>>());
|
|
};
|
|
|
|
pyGuidedDecodingParams
|
|
.def(py::init<tle::GuidedDecodingParams::GuideType, std::optional<std::string>>(), py::arg("guide_type"),
|
|
py::arg("guide") = py::none())
|
|
.def_property_readonly("guide_type", &tle::GuidedDecodingParams::getGuideType)
|
|
.def_property_readonly("guide", &tle::GuidedDecodingParams::getGuide)
|
|
.def(py::pickle(guidedDecodingParamsGetstate, guidedDecodingParamsSetstate));
|
|
|
|
auto requestGetstate = [](tle::Request const& self)
|
|
{
|
|
return py::make_tuple(self.getInputTokenIds(), self.getMaxTokens(), self.getStreaming(),
|
|
self.getSamplingConfig(), self.getOutputConfig(), self.getEndId(), self.getPadId(), self.getPositionIds(),
|
|
self.getBadWords(), self.getStopWords(), self.getEmbeddingBias(), self.getExternalDraftTokensConfig(),
|
|
self.getPromptTuningConfig(), self.getMultimodalInput(), self.getMultimodalEmbedding(),
|
|
self.getMropeConfig(), self.getLoraConfig(), self.getLookaheadConfig(), self.getKvCacheRetentionConfig(),
|
|
self.getLogitsPostProcessorName(), self.getLogitsPostProcessor(), self.getEncoderInputTokenIds(),
|
|
self.getClientId(), self.getReturnAllGeneratedTokens(), self.getPriority(), self.getRequestType(),
|
|
self.getContextPhaseParams(), self.getEncoderInputFeatures(), self.getEncoderOutputLength(),
|
|
self.getCrossAttentionMask(), self.getEagleConfig(), self.getSkipCrossAttnBlocks(),
|
|
self.getGuidedDecodingParams(), self.getCacheSaltID());
|
|
};
|
|
auto requestSetstate = [](py::tuple const& state)
|
|
{
|
|
if (state.size() != 34)
|
|
{
|
|
throw std::runtime_error("Invalid Request state!");
|
|
}
|
|
return std::make_unique<tle::Request>(state[0].cast<VecTokens>(), state[1].cast<SizeType32>(),
|
|
state[2].cast<bool>(), state[3].cast<tle::SamplingConfig>(), state[4].cast<tle::OutputConfig>(),
|
|
state[5].cast<std::optional<SizeType32>>(), state[6].cast<std::optional<SizeType32>>(),
|
|
state[7].cast<std::optional<std::vector<SizeType32>>>(),
|
|
state[8].cast<std::optional<std::list<VecTokens>>>(), state[9].cast<std::optional<std::list<VecTokens>>>(),
|
|
state[10].cast<std::optional<Tensor>>(), state[11].cast<std::optional<tle::ExternalDraftTokensConfig>>(),
|
|
state[12].cast<std::optional<tle::PromptTuningConfig>>(),
|
|
state[13].cast<std::optional<tle::MultimodalInput>>(), state[14].cast<std::optional<Tensor>>(),
|
|
state[15].cast<std::optional<tle::MropeConfig>>(), state[16].cast<std::optional<tle::LoraConfig>>(),
|
|
state[17].cast<std::optional<tle::LookaheadDecodingConfig>>(),
|
|
state[18].cast<std::optional<tle::KvCacheRetentionConfig>>(), state[19].cast<std::optional<std::string>>(),
|
|
state[20].cast<std::optional<tle::LogitsPostProcessor>>(), state[21].cast<std::optional<VecTokens>>(),
|
|
state[22].cast<std::optional<IdType>>(), state[23].cast<bool>(), state[24].cast<tle::PriorityType>(),
|
|
state[25].cast<tle::RequestType>(), state[26].cast<std::optional<tle::ContextPhaseParams>>(),
|
|
state[27].cast<std::optional<tle::Tensor>>(), state[28].cast<std::optional<SizeType32>>(),
|
|
state[29].cast<std::optional<tle::Tensor>>(), 1, state[30].cast<std::optional<tle::EagleConfig>>(),
|
|
state[31].cast<std::optional<tle::Tensor>>(), state[32].cast<std::optional<tle::GuidedDecodingParams>>(),
|
|
state[33].cast<std::optional<tle::CacheSaltIDType>>());
|
|
};
|
|
|
|
py::class_<tle::Request> request(m, "Request", pybind11::dynamic_attr());
|
|
request
|
|
.def(py::init<tle::VecTokens, // inputTokenIds
|
|
tle::SizeType32, // maxTokens
|
|
bool, // streaming
|
|
tle::SamplingConfig const&, // samplingConfig
|
|
tle::OutputConfig const&, // outputConfig
|
|
std::optional<tle::SizeType32> const&, // endId
|
|
std::optional<tle::SizeType32> const&, // padId
|
|
std::optional<std::vector<SizeType32>>, // positionIds
|
|
std::optional<std::list<tle::VecTokens>>, // badWords
|
|
std::optional<std::list<tle::VecTokens>>, // stopWords
|
|
std::optional<tle::Tensor>, // embeddingBias
|
|
std::optional<tle::ExternalDraftTokensConfig>, // externalDraftTokensConfig
|
|
std::optional<tle::PromptTuningConfig>, // pTuningConfig
|
|
std::optional<tle::MultimodalInput>, // multimodalInput
|
|
std::optional<tle::Tensor>, // multimodalEmbedding
|
|
std::optional<tle::MropeConfig>, // mRopeConfig
|
|
std::optional<tle::LoraConfig>, // loraConfig
|
|
std::optional<tle::LookaheadDecodingConfig>, // lookaheadConfig
|
|
std::optional<tle::KvCacheRetentionConfig>, // kvCacheRetentionConfig
|
|
std::optional<std::string>, // logitsPostProcessorName
|
|
std::optional<tle::LogitsPostProcessor>, // logitsPostProcessor
|
|
std::optional<tle::VecTokens>, // encoderInputTokenIds
|
|
std::optional<tle::IdType>, // clientId
|
|
bool, // returnAllGeneratedTokens
|
|
tle::PriorityType, // priority
|
|
tle::RequestType, // type
|
|
std::optional<tle::ContextPhaseParams>, // contextPhaseParams
|
|
std::optional<tle::Tensor>, // encoderInputFeatures
|
|
std::optional<tle::SizeType32>, // encoderOutputLength
|
|
std::optional<tle::Tensor>, // crossAttentionMask
|
|
SizeType32, // numReturnSequences
|
|
std::optional<tle::EagleConfig>, // eagleConfig
|
|
std::optional<tle::Tensor>, // skipCrossAttnBlocks
|
|
std::optional<tle::GuidedDecodingParams>, // guidedDecodingParams
|
|
std::optional<tle::SizeType32>, // languageAdapterUid
|
|
std::optional<tle::MillisecondsType>, // allottedTimeMs
|
|
std::optional<tle::CacheSaltIDType> // cacheSaltID
|
|
>(),
|
|
// clang-format off
|
|
py::arg("input_token_ids"),
|
|
py::arg("max_tokens"),
|
|
py::kw_only(),
|
|
py::arg("streaming") = false,
|
|
py::arg_v("sampling_config", tle::SamplingConfig(), "SamplingConfig()"),
|
|
py::arg_v("output_config", tle::OutputConfig(), "OutputConfig()"),
|
|
py::arg("end_id") = py::none(),
|
|
py::arg("pad_id") = py::none(),
|
|
py::arg("position_ids") = py::none(),
|
|
py::arg("bad_words") = py::none(),
|
|
py::arg("stop_words") = py::none(),
|
|
py::arg("embedding_bias") = py::none(),
|
|
py::arg("external_draft_tokens_config") = py::none(),
|
|
py::arg("prompt_tuning_config") = py::none(),
|
|
py::arg("multimodal_input") = py::none(),
|
|
py::arg("multimodal_embedding") = py::none(),
|
|
py::arg("mrope_config") = py::none(),
|
|
py::arg("lora_config") = py::none(),
|
|
py::arg("lookahead_config") = py::none(),
|
|
py::arg("kv_cache_retention_config") = py::none(),
|
|
py::arg("logits_post_processor_name") = py::none(),
|
|
py::arg("logits_post_processor") = py::none(),
|
|
py::arg("encoder_input_token_ids") = py::none(),
|
|
py::arg("client_id") = py::none(),
|
|
py::arg("return_all_generated_tokens") = false,
|
|
py::arg("priority") = tle::Request::kDefaultPriority,
|
|
py::arg_v("type", tle::RequestType::REQUEST_TYPE_CONTEXT_AND_GENERATION,
|
|
"RequestType.REQUEST_TYPE_CONTEXT_AND_GENERATION"),
|
|
py::arg("context_phase_params") = py::none(),
|
|
py::arg("encoder_input_features") = py::none(),
|
|
py::arg("encoder_output_length") = py::none(),
|
|
py::arg("cross_attention_mask") = py::none(),
|
|
py::arg("num_return_sequences") = 1,
|
|
py::arg("eagle_config") = py::none(),
|
|
py::arg("skip_cross_attn_blocks") = py::none(),
|
|
py::arg("guided_decoding_params") = py::none(),
|
|
py::arg("language_adapter_uid") = py::none(),
|
|
py::arg("allotted_time_ms") = py::none(),
|
|
py::arg("cache_salt_id") = py::none()
|
|
) // clang-format on
|
|
.def_property_readonly("input_token_ids", &tle::Request::getInputTokenIds)
|
|
.def_property_readonly("max_tokens", &tle::Request::getMaxTokens)
|
|
.def_property("streaming", &tle::Request::getStreaming, &tle::Request::setStreaming)
|
|
.def_property("sampling_config", &tle::Request::getSamplingConfig, &tle::Request::setSamplingConfig)
|
|
.def_property("output_config", &tle::Request::getOutputConfig, &tle::Request::setOutputConfig)
|
|
.def_property("end_id", &tle::Request::getEndId, &tle::Request::setEndId)
|
|
.def_property("pad_id", &tle::Request::getPadId, &tle::Request::setPadId)
|
|
.def_property("position_ids", &tle::Request::getPositionIds, &tle::Request::setPositionIds)
|
|
.def_property("bad_words", &tle::Request::getBadWords, &tle::Request::setBadWords)
|
|
.def_property("stop_words", &tle::Request::getStopWords, &tle::Request::setStopWords)
|
|
.def_property("embedding_bias", &tle::Request::getEmbeddingBias, &tle::Request::setEmbeddingBias)
|
|
.def_property("external_draft_tokens_config", &tle::Request::getExternalDraftTokensConfig,
|
|
&tle::Request::setExternalDraftTokensConfig)
|
|
.def_property(
|
|
"prompt_tuning_config", &tle::Request::getPromptTuningConfig, &tle::Request::setPromptTuningConfig)
|
|
.def_property("multimodal_input", &tle::Request::getMultimodalInput, &tle::Request::setMultimodalInput)
|
|
.def_property(
|
|
"multimodal_embedding", &tle::Request::getMultimodalEmbedding, &tle::Request::setMultimodalEmbedding)
|
|
.def_property("mrope_config", &tle::Request::getMropeConfig, &tle::Request::setMropeConfig)
|
|
.def_property("lora_config", &tle::Request::getLoraConfig, &tle::Request::setLoraConfig)
|
|
.def_property("lookahead_config", &tle::Request::getLookaheadConfig, &tle::Request::setLookaheadConfig)
|
|
.def_property("kv_cache_retention_config", &tle::Request::getKvCacheRetentionConfig,
|
|
&tle::Request::setKvCacheRetentionConfig)
|
|
.def_property("logits_post_processor_name", &tle::Request::getLogitsPostProcessorName,
|
|
&tle::Request::setLogitsPostProcessorName)
|
|
.def_property(
|
|
"logits_post_processor", &tle::Request::getLogitsPostProcessor, &tle::Request::setLogitsPostProcessor)
|
|
.def_property(
|
|
"encoder_input_token_ids", &tle::Request::getEncoderInputTokenIds, &tle::Request::setEncoderInputTokenIds)
|
|
.def_property("client_id", &tle::Request::getClientId, &tle::Request::setClientId)
|
|
.def_property("return_all_generated_tokens", &tle::Request::getReturnAllGeneratedTokens,
|
|
&tle::Request::setReturnAllGeneratedTokens)
|
|
.def_property("request_type", &tle::Request::getRequestType, &tle::Request::setRequestType)
|
|
.def_property(
|
|
"encoder_input_features", &tle::Request::getEncoderInputFeatures, &tle::Request::setEncoderInputFeatures)
|
|
.def_property(
|
|
"cross_attention_mask", &tle::Request::getCrossAttentionMask, &tle::Request::setCrossAttentionMask)
|
|
.def_property("eagle_config", &tle::Request::getEagleConfig, &tle::Request::setEagleConfig)
|
|
.def_property(
|
|
"skip_cross_attn_blocks", &tle::Request::getSkipCrossAttnBlocks, &tle::Request::setSkipCrossAttnBlocks)
|
|
.def_property(
|
|
"guided_decoding_params", &tle::Request::getGuidedDecodingParams, &tle::Request::setGuidedDecodingParams)
|
|
.def_property("allotted_time_ms", &tle::Request::getAllottedTimeMs, &tle::Request::setAllottedTimeMs)
|
|
.def_property("cache_salt_id", &tle::Request::getCacheSaltID, &tle::Request::setCacheSaltID)
|
|
.def_property(
|
|
"context_phase_params", &tle::Request::getContextPhaseParams, &tle::Request::setContextPhaseParams)
|
|
.def(py::pickle(requestGetstate, requestSetstate));
|
|
request.attr("BATCHED_POST_PROCESSOR_NAME") = tle::Request::kBatchedPostProcessorName;
|
|
|
|
py::class_<tle::SpeculativeDecodingFastLogitsInfo>(m, "SpeculativeDecodingFastLogitsInfo")
|
|
.def(py::init<>())
|
|
.def_readwrite("draft_request_id", &tle::SpeculativeDecodingFastLogitsInfo::draftRequestId)
|
|
.def_readwrite("draft_participant_id", &tle::SpeculativeDecodingFastLogitsInfo::draftParticipantId)
|
|
.def("to_tensor", &tle::SpeculativeDecodingFastLogitsInfo::toTensor);
|
|
|
|
auto requestPerfMetrics = py::class_<tle::RequestPerfMetrics>(m, "RequestPerfMetrics");
|
|
|
|
auto timingMetricsGetstate = [](tle::RequestPerfMetrics::TimingMetrics const& self)
|
|
{
|
|
return py::make_tuple(self.arrivalTime, self.firstScheduledTime, self.firstTokenTime, self.lastTokenTime,
|
|
self.kvCacheTransferStart, self.kvCacheTransferEnd, self.kvCacheSize);
|
|
};
|
|
auto timingMetricsSetstate = [](py::tuple const& state)
|
|
{
|
|
if (state.size() != 7)
|
|
{
|
|
throw std::runtime_error("Invalid TimingMetrics state!");
|
|
}
|
|
return tle::RequestPerfMetrics::TimingMetrics{state[0].cast<tle::RequestPerfMetrics::TimePoint>(),
|
|
state[1].cast<tle::RequestPerfMetrics::TimePoint>(), state[2].cast<tle::RequestPerfMetrics::TimePoint>(),
|
|
state[3].cast<tle::RequestPerfMetrics::TimePoint>(), state[4].cast<tle::RequestPerfMetrics::TimePoint>(),
|
|
state[5].cast<tle::RequestPerfMetrics::TimePoint>(), state[6].cast<size_t>()};
|
|
};
|
|
py::class_<tle::RequestPerfMetrics::TimingMetrics>(m, "TimingMetrics")
|
|
.def(py::init<>())
|
|
.def_readwrite("arrival_time", &tle::RequestPerfMetrics::TimingMetrics::arrivalTime)
|
|
.def_readwrite("first_scheduled_time", &tle::RequestPerfMetrics::TimingMetrics::firstScheduledTime)
|
|
.def_readwrite("first_token_time", &tle::RequestPerfMetrics::TimingMetrics::firstTokenTime)
|
|
.def_readwrite("last_token_time", &tle::RequestPerfMetrics::TimingMetrics::lastTokenTime)
|
|
.def_readwrite("kv_cache_transfer_start", &tle::RequestPerfMetrics::TimingMetrics::kvCacheTransferStart)
|
|
.def_readwrite("kv_cache_transfer_end", &tle::RequestPerfMetrics::TimingMetrics::kvCacheTransferEnd)
|
|
.def_readwrite("kv_cache_size", &tle::RequestPerfMetrics::TimingMetrics::kvCacheSize)
|
|
.def(py::pickle(timingMetricsGetstate, timingMetricsSetstate));
|
|
|
|
auto kvCacheMetricsGetstate = [](tle::RequestPerfMetrics::KvCacheMetrics const& self)
|
|
{
|
|
return py::make_tuple(self.numTotalAllocatedBlocks, self.numNewAllocatedBlocks, self.numReusedBlocks,
|
|
self.numMissedBlocks, self.kvCacheHitRate);
|
|
};
|
|
auto kvCacheMetricsSetstate = [](py::tuple const& state)
|
|
{
|
|
if (state.size() != 5)
|
|
{
|
|
throw std::runtime_error("Invalid KvCacheMetrics state!");
|
|
}
|
|
return tle::RequestPerfMetrics::KvCacheMetrics{state[0].cast<SizeType32>(), state[1].cast<SizeType32>(),
|
|
state[2].cast<SizeType32>(), state[3].cast<SizeType32>(), state[4].cast<float>()};
|
|
};
|
|
py::class_<tle::RequestPerfMetrics::KvCacheMetrics>(m, "KvCacheMetrics")
|
|
.def(py::init<>())
|
|
.def_readwrite("num_total_allocated_blocks", &tle::RequestPerfMetrics::KvCacheMetrics::numTotalAllocatedBlocks)
|
|
.def_readwrite("num_new_allocated_blocks", &tle::RequestPerfMetrics::KvCacheMetrics::numNewAllocatedBlocks)
|
|
.def_readwrite("num_reused_blocks", &tle::RequestPerfMetrics::KvCacheMetrics::numReusedBlocks)
|
|
.def_readwrite("num_missed_blocks", &tle::RequestPerfMetrics::KvCacheMetrics::numMissedBlocks)
|
|
.def_readwrite("kv_cache_hit_rate", &tle::RequestPerfMetrics::KvCacheMetrics::kvCacheHitRate)
|
|
.def(py::pickle(kvCacheMetricsGetstate, kvCacheMetricsSetstate));
|
|
|
|
auto speculativeDecodingMetricsGetstate = [](tle::RequestPerfMetrics::SpeculativeDecodingMetrics const& self)
|
|
{ return py::make_tuple(self.acceptanceRate, self.totalAcceptedDraftTokens, self.totalDraftTokens); };
|
|
auto speculativeDecodingMetricsSetstate = [](py::tuple const& state)
|
|
{
|
|
if (state.size() != 3)
|
|
{
|
|
throw std::runtime_error("Invalid SpeculativeDecodingMetrics state!");
|
|
}
|
|
return tle::RequestPerfMetrics::SpeculativeDecodingMetrics{
|
|
state[0].cast<float>(), state[1].cast<SizeType32>(), state[2].cast<SizeType32>()};
|
|
};
|
|
|
|
py::class_<tle::RequestPerfMetrics::SpeculativeDecodingMetrics>(m, "SpeculativeDecodingMetrics")
|
|
.def(py::init<>())
|
|
.def_readwrite("acceptance_rate", &tle::RequestPerfMetrics::SpeculativeDecodingMetrics::acceptanceRate)
|
|
.def_readwrite("total_accepted_draft_tokens",
|
|
&tle::RequestPerfMetrics::SpeculativeDecodingMetrics::totalAcceptedDraftTokens)
|
|
.def_readwrite("total_draft_tokens", &tle::RequestPerfMetrics::SpeculativeDecodingMetrics::totalDraftTokens)
|
|
.def(py::pickle(speculativeDecodingMetricsGetstate, speculativeDecodingMetricsSetstate));
|
|
|
|
auto requestPerfMetricsGetstate = [](tle::RequestPerfMetrics const& self)
|
|
{
|
|
return py::make_tuple(self.timingMetrics, self.kvCacheMetrics, self.speculativeDecoding, self.firstIter,
|
|
self.lastIter, self.iter);
|
|
};
|
|
auto requestPerfMetricsSetstate = [](py::tuple const& state)
|
|
{
|
|
if (state.size() != 6)
|
|
{
|
|
throw std::runtime_error("Invalid RequestPerfMetrics state!");
|
|
}
|
|
return tle::RequestPerfMetrics{state[0].cast<tle::RequestPerfMetrics::TimingMetrics>(),
|
|
state[1].cast<tle::RequestPerfMetrics::KvCacheMetrics>(),
|
|
state[2].cast<tle::RequestPerfMetrics::SpeculativeDecodingMetrics>(),
|
|
state[3].cast<std::optional<tle::IterationType>>(), state[4].cast<std::optional<tle::IterationType>>(),
|
|
state[5].cast<std::optional<tle::IterationType>>()};
|
|
};
|
|
|
|
// There's a circular dependency between the declaration of the TimingMetrics and RequestPerfMetrics bindings.
|
|
// Defer definition of the RequestPerfMetrics bindings until the TimingMetrics have been defined.
|
|
requestPerfMetrics.def(py::init<>())
|
|
.def_readwrite("timing_metrics", &tle::RequestPerfMetrics::timingMetrics)
|
|
.def_readwrite("kv_cache_metrics", &tle::RequestPerfMetrics::kvCacheMetrics)
|
|
.def_readwrite("speculative_decoding", &tle::RequestPerfMetrics::speculativeDecoding)
|
|
.def_readwrite("first_iter", &tle::RequestPerfMetrics::firstIter)
|
|
.def_readwrite("last_iter", &tle::RequestPerfMetrics::lastIter)
|
|
.def_readwrite("iter", &tle::RequestPerfMetrics::iter)
|
|
.def(py::pickle(requestPerfMetricsGetstate, requestPerfMetricsSetstate));
|
|
|
|
py::class_<tle::AdditionalOutput>(m, "AdditionalOutput")
|
|
.def(py::init([](std::string const& name, tle::Tensor const& output)
|
|
{ return std::make_unique<tle::AdditionalOutput>(name, output); }))
|
|
.def_readwrite("name", &tle::AdditionalOutput::name)
|
|
.def_readwrite("output", &tle::AdditionalOutput::output);
|
|
|
|
auto resultSetstate = [](py::tuple const& state)
|
|
{
|
|
if (state.size() != 14)
|
|
{
|
|
throw std::runtime_error("Invalid Request state!");
|
|
}
|
|
tle::Result result;
|
|
result.isFinal = state[0].cast<bool>();
|
|
result.outputTokenIds = state[1].cast<std::vector<VecTokens>>();
|
|
result.cumLogProbs = state[2].cast<std::optional<std::vector<float>>>();
|
|
result.logProbs = state[3].cast<std::optional<std::vector<std::vector<float>>>>();
|
|
result.contextLogits = state[4].cast<std::optional<Tensor>>();
|
|
result.generationLogits = state[5].cast<std::optional<Tensor>>();
|
|
result.encoderOutput = state[6].cast<std::optional<Tensor>>();
|
|
result.finishReasons = state[7].cast<std::vector<tle::FinishReason>>();
|
|
result.sequenceIndex = state[8].cast<SizeType32>();
|
|
result.isSequenceFinal = state[9].cast<bool>();
|
|
result.decodingIter = state[10].cast<SizeType32>();
|
|
result.avgDecodedTokensPerIter = state[11].cast<float>();
|
|
result.contextPhaseParams = state[12].cast<std::optional<tle::ContextPhaseParams>>();
|
|
result.requestPerfMetrics = state[13].cast<std::optional<tle::RequestPerfMetrics>>();
|
|
return std::make_unique<tle::Result>(result);
|
|
};
|
|
|
|
auto resultGetstate = [](tle::Result const& self)
|
|
{
|
|
return py::make_tuple(self.isFinal, self.outputTokenIds, self.cumLogProbs, self.logProbs, self.contextLogits,
|
|
self.generationLogits, self.encoderOutput, self.finishReasons, self.sequenceIndex, self.isSequenceFinal,
|
|
self.decodingIter, self.avgDecodedTokensPerIter, self.contextPhaseParams, self.requestPerfMetrics);
|
|
};
|
|
|
|
py::class_<tle::Result>(m, "Result")
|
|
.def(py::init<>())
|
|
.def_readwrite("is_final", &tle::Result::isFinal)
|
|
.def_readwrite("output_token_ids", &tle::Result::outputTokenIds)
|
|
.def_readwrite("cum_log_probs", &tle::Result::cumLogProbs)
|
|
.def_readwrite("log_probs", &tle::Result::logProbs)
|
|
.def_readwrite("context_logits", &tle::Result::contextLogits)
|
|
.def_readwrite("generation_logits", &tle::Result::generationLogits)
|
|
.def_readwrite("spec_dec_fast_logits_info", &tle::Result::specDecFastLogitsInfo)
|
|
.def_readwrite("encoder_output", &tle::Result::encoderOutput)
|
|
.def_readwrite("finish_reasons", &tle::Result::finishReasons)
|
|
.def_readwrite("sequence_index", &tle::Result::sequenceIndex)
|
|
.def_readwrite("is_sequence_final", &tle::Result::isSequenceFinal)
|
|
.def_readwrite("decoding_iter", &tle::Result::decodingIter)
|
|
.def_readwrite("avg_decoded_tokens_per_iter", &tle::Result::avgDecodedTokensPerIter)
|
|
.def_readwrite("context_phase_params", &tle::Result::contextPhaseParams)
|
|
.def_readwrite("request_perf_metrics", &tle::Result::requestPerfMetrics)
|
|
.def_readwrite("additional_outputs", &tle::Result::additionalOutputs)
|
|
.def(py::pickle(resultGetstate, resultSetstate));
|
|
|
|
m.def("deserialize_result",
|
|
[](std::string& x)
|
|
{
|
|
std::istringstream is(x);
|
|
return tle::serialize_utils::deserialize<tle::Result>(is);
|
|
});
|
|
|
|
auto responseGetstate = [](tle::Response const& self)
|
|
{ return py::make_tuple(self.getRequestId(), self.getResult(), self.getClientId()); };
|
|
|
|
auto responseSetstate = [](py::tuple const& state)
|
|
{
|
|
if (state.size() != 3)
|
|
{
|
|
throw std::runtime_error("Invalid Request state!");
|
|
}
|
|
return std::make_unique<tle::Response>(
|
|
state[0].cast<SizeType32>(), state[1].cast<tle::Result>(), state[2].cast<SizeType32>());
|
|
};
|
|
|
|
py::class_<tle::Response>(m, "Response")
|
|
.def(py::init<IdType, std::string, std::optional<IdType>>(), py::arg("request_id"), py::arg("error_msg"),
|
|
py::arg("client_id") = std::nullopt)
|
|
.def(py::init<IdType, tle::Result, std::optional<IdType>>(), py::arg("request_id"), py::arg("result"),
|
|
py::arg("client_id") = std::nullopt)
|
|
.def_property_readonly("request_id", &tle::Response::getRequestId)
|
|
.def_property_readonly("client_id", &tle::Response::getClientId)
|
|
.def("has_error", &tle::Response::hasError)
|
|
.def_property_readonly("error_msg", &tle::Response::getErrorMsg)
|
|
.def_property_readonly("result", &tle::Response::getResult)
|
|
.def("clear_context_logits",
|
|
[](tle::Response& self)
|
|
{
|
|
if (!self.hasError())
|
|
{
|
|
auto& result = const_cast<tle::Result&>(self.getResult());
|
|
result.contextLogits.reset();
|
|
}
|
|
})
|
|
.def("clear_generation_logits",
|
|
[](tle::Response& self)
|
|
{
|
|
if (!self.hasError())
|
|
{
|
|
auto& result = const_cast<tle::Result&>(self.getResult());
|
|
result.generationLogits.reset();
|
|
}
|
|
})
|
|
.def(py::pickle(responseGetstate, responseSetstate));
|
|
}
|
|
|
|
} // namespace tensorrt_llm::pybind::executor
|