/* * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "llmRequest.h" #include "tensorrt_llm/batch_manager/llmRequest.h" #include "tensorrt_llm/pybind/utils/bindTypes.h" #include "tensorrt_llm/runtime/torch.h" #include "tensorrt_llm/runtime/torchUtils.h" #include "tensorrt_llm/runtime/torchView.h" #include #include #include #include #include #include #include namespace tb = tensorrt_llm::batch_manager; namespace tr = tensorrt_llm::runtime; namespace tle = tensorrt_llm::executor; using namespace tensorrt_llm::pybind::batch_manager; using LlmRequestPtr = std::shared_ptr; using RequestList = std::list; namespace { std::optional from_torch(std::optional torchPtr) { if (torchPtr) { return tr::TorchView::of(torchPtr.value()); } return std::nullopt; } } // namespace std::optional LlmRequest::callbackAdapter( std::optional callback) { if (!callback) { return std::nullopt; } return [callback](RequestIdType reqId, tensorrt_llm::runtime::ITensor::SharedPtr& tensor, tensorrt_llm::batch_manager::LlmRequest::BeamTokens const& tokens, tensorrt_llm::runtime::BufferManager::CudaStreamPtr stream, std::optional clientId) { at::Tensor atTensor = tr::Torch::tensor(tensor); callback.value()(reqId, atTensor, tokens, runtime::TorchUtils::stream(*stream).unwrap(), clientId); }; } std::shared_ptr LlmRequest::toTrtLlm() const { auto embeddingBias = from_torch(mEmbeddingBias); auto badWordsList = from_torch(mBadWordsList); auto stopWordsList = from_torch(mStopWordsList); auto promptEmbeddingTable = from_torch(mPromptEmbeddingTable); auto loraWeights = from_torch(mLoraWeights); auto loraConfig = from_torch(mLoraConfig); auto draftLogits = from_torch(mDraftLogits); auto encoderInputFeatures = from_torch(mEncoderInputFeatures); return std::make_shared(mRequestId, mMaxNewTokens, std::make_shared>(mTokens.at(0)), mSamplingConfig, mIsStreaming, mEndId, mPadId, embeddingBias, badWordsList, stopWordsList, mPositionIds, promptEmbeddingTable, mPromptVocabSize, mLoraTaskId, loraWeights, loraConfig, mLookaheadConfig, returnLogProbs(), mReturnContextLogits, mReturnGenerationLogits, mDraftTokens, draftLogits, mExcludeInputFromOutput, callbackAdapter(mLogitsPostProcessor), mApplyLogitsPostProcessorBatched, mEncoderTokens, mReturnEncoderOutput, mClientId, mPriority, encoderInputFeatures, mEncoderOutputLength, tb::LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, mInputTokenExtraIds, mNumReturnSequences); } void LlmRequest::initBindings(py::module_& m) { py::class_(m, "LlmRequest") .def(py::init, std::optional, std::optional, std::optional, std::optional, std::optional>, std::optional, std::optional, std::optional, std::optional, std::optional, std::optional, bool, bool, bool, std::optional, std::optional, bool, std::optional, bool, std::optional, bool, std::optional, executor::PriorityType, std::optional, std::optional, std::optional, LlmRequest::SizeType32>(), py::arg("request_id"), py::arg("max_new_tokens"), py::arg("input_tokens"), py::arg("sampling_config"), py::arg("is_streaming"), py::arg("end_id") = std::nullopt, py::arg("pad_id") = std::nullopt, py::arg("embedding_bias") = std::nullopt, py::arg("bad_words_list") = std::nullopt, py::arg("stop_words_list") = std::nullopt, py::arg("position_ids") = std::nullopt, py::arg("prompt_embedding_table") = std::nullopt, py::arg("prompt_vocab_size") = std::nullopt, py::arg("lora_task_id") = std::nullopt, py::arg("lora_weights") = std::nullopt, py::arg("lora_config") = std::nullopt, py::arg("lookahead_config") = std::nullopt, py::arg("return_log_probs") = false, py::arg("return_context_logits") = false, py::arg("return_generation_logits") = false, py::arg("draft_tokens") = std::nullopt, py::arg("draft_logits") = std::nullopt, py::arg("exclude_input_from_output") = false, py::arg("logits_post_processor") = std::nullopt, py::arg("apply_logits_post_processor_batched") = false, py::arg("encoder_input_tokens") = std::nullopt, py::arg("return_encoder_output") = false, py::arg("client_id") = std::nullopt, py::arg("priority") = executor::Request::kDefaultPriority, py::arg("encoder_input_features") = std::nullopt, py::arg("encoder_output_length") = std::nullopt, py::arg("input_token_extra_ids") = std::nullopt, py::arg("num_return_sequences") = 1) .def("get_num_tokens", &LlmRequest::getNumTokens, py::arg("beam")) .def_property_readonly("max_beam_num_tokens", &LlmRequest::getMaxBeamNumTokens) .def("get_token", &LlmRequest::getToken, py::arg("beam"), py::arg("pos")) .def("get_tokens", py::overload_cast(&LlmRequest::getTokens, py::const_), py::arg("beam")) .def("get_tokens", py::overload_cast<>(&LlmRequest::getTokens, py::const_)) .def_property_readonly("max_num_generated_tokens", &LlmRequest::getMaxNumGeneratedTokens) .def("add_new_token", &LlmRequest::addNewToken, py::arg("token"), py::arg("beam")) .def("add_new_tokens", &LlmRequest::addNewTokens, py::arg("beam_tokens")) .def("set_generated_tokens", &LlmRequest::setGeneratedTokens, py::arg("generated_beam_tokens")) .def("pause", &LlmRequest::pause, py::arg("max_input_len")) .def_property("max_sent_token_len", &LlmRequest::getMaxSentTokenLen, &LlmRequest::setMaxSentTokenLen) .def_property_readonly("prompt_embedding_table", &LlmRequest::getPromptEmbeddingTable) .def_property_readonly("prompt_vocab_size", &LlmRequest::getPromptVocabSize) .def_property_readonly("lora_task_id", &LlmRequest::getLoraTaskId) .def_property_readonly("lora_weights", &LlmRequest::getLoraWeights) .def_property_readonly("lora_config", &LlmRequest::getLoraConfig) .def_property_readonly("lookahead_config", &LlmRequest::getLookaheadConfig) .def_property_readonly("embedding_bias", &LlmRequest::getEmbeddingBias) .def_property_readonly("bad_words_list", &LlmRequest::getBadWordsList) .def_property_readonly("stop_words_list", &LlmRequest::getStopWordsList) .def_property_readonly("position_ids", [](LlmRequest& self) { return *self.getPositionIds().value_or(std::make_shared>()); }) .def_property_readonly( "context_current_position", py::overload_cast<>(&LlmRequest::getContextCurrentPosition, py::const_)) .def_property("context_chunk_size", &LlmRequest::getContextChunkSize, &LlmRequest::setContextChunkSize) .def_readwrite("request_id", &LlmRequest::mRequestId) .def_readwrite("prompt_len", &LlmRequest::mPromptLen) .def_readwrite("max_new_tokens", &LlmRequest::mMaxNewTokens) .def_readwrite("sampling_config", &LlmRequest::mSamplingConfig) .def_readwrite("state", &LlmRequest::mState) .def_readwrite("is_streaming", &LlmRequest::mIsStreaming) .def_readwrite("end_id", &LlmRequest::mEndId) .def_readwrite("pad_id", &LlmRequest::mPadId) .def_readwrite("seq_slot", &LlmRequest::mSeqSlot) .def_property_readonly("return_log_probs", &LlmRequest::returnLogProbs) .def_property_readonly("return_context_logits", &LlmRequest::setReturnContextLogits) .def_property_readonly("return_generation_logits", &LlmRequest::setReturnGenerationLogits) .def_property_readonly("log_probs", py::overload_cast<>(&LlmRequest::getLogProbs, py::const_)) .def("get_log_probs", py::overload_cast(&LlmRequest::getLogProbs, py::const_)) .def("set_log_probs", &LlmRequest::setLogProbs, py::arg("log_probs"), py::arg("beam")) .def("set_return_encoder_output", &LlmRequest::setReturnEncoderOutput, py::arg("return_encoder_output")) .def("get_return_encoder_output", &LlmRequest::getReturnEncoderOutput) .def("priority", py::overload_cast<>(&LlmRequest::priority, py::const_)) .def("set_priority", py::overload_cast(&LlmRequest::setPriority)) .def_property_readonly("cum_log_probs", &LlmRequest::getCumLogProbs) .def("set_cum_log_prob", &LlmRequest::setCumLogProb, py::arg("cum_log_prob"), py::arg("beam")) .def_property_readonly("orig_prompt_len", &LlmRequest::getOrigPromptLen) .def("has_draft_tokens", &LlmRequest::hasDraftTokens) .def("move_to_next_context_chunk", &LlmRequest::moveToNextContextChunk) .def("is_last_context_chunk", py::overload_cast<>(&LlmRequest::isLastContextChunk, py::const_)) .def("is_first_context_chunk", py::overload_cast<>(&LlmRequest::isFirstContextChunk, py::const_)) .def("get_context_remaining_length", py::overload_cast<>(&LlmRequest::getContextRemainingLength, py::const_)) .def_property( "draft_tokens", [](LlmRequest& self) { return *self.getDraftTokens(); }, [](LlmRequest& self, LlmRequest::VecTokens& draftTokens) { self.setDraftTokens(std::make_shared(std::move(draftTokens))); }) .def_property( "draft_logits", [](LlmRequest& self) { return self.getDraftLogits(); }, [](LlmRequest& self, LlmRequest::TensorPtr& logits) { self.setDraftLogits(std::make_optional(logits)); }) .def_property("num_return_sequences", &LlmRequest::getNumReturnSequences, &LlmRequest::setNumReturnSequences); } void tb::LlmRequestBindings::initBindings(py::module_& m) { py::classh(m, "PyLlmRequest") .def("get_num_tokens", &tb::LlmRequest::getNumTokens, py::arg("beam")) .def_property_readonly("max_beam_num_tokens", &tb::LlmRequest::getMaxBeamNumTokens) .def("get_token", &tb::LlmRequest::getToken, py::arg("beam"), py::arg("pos")) .def("get_tokens", py::overload_cast(&tb::LlmRequest::getTokens, py::const_), py::arg("beam")) .def("get_tokens", py::overload_cast<>(&tb::LlmRequest::getTokens, py::const_)) .def_property_readonly("max_num_generated_tokens", &tb::LlmRequest::getMaxNumGeneratedTokens) .def("add_new_token", &tb::LlmRequest::addNewToken, py::arg("token"), py::arg("beam")) .def("add_new_tokens", &tb::LlmRequest::addNewTokens, py::arg("beam_tokens")) .def("set_generated_tokens", &tb::LlmRequest::setGeneratedTokens, py::arg("generated_beam_tokens")) .def("pause", &tb::LlmRequest::pause, py::arg("max_input_len")) .def_property("max_sent_token_len", &tb::LlmRequest::getMaxSentTokenLen, &tb::LlmRequest::setMaxSentTokenLen) .def("prompt_embedding_table", [](tb::LlmRequest& self) { std::optional value{std::nullopt}; auto tensor = self.getPromptEmbeddingTable(); if (tensor) { value = tr::Torch::tensor(*tensor); } return value; }) .def("bad_words_list", [](tb::LlmRequest& self) { std::optional value{std::nullopt}; auto tensor = self.getBadWordsList(); if (tensor) { value = tr::Torch::tensor(*tensor); } return value; }) .def_property( "draft_logits", [](tb::LlmRequest& self) { std::optional value{std::nullopt}; auto tensor = self.getDraftLogits(); if (tensor) { value = tr::Torch::tensor(*tensor); } return value; }, [](tb::LlmRequest& self, at::Tensor& logits) { self.setDraftLogits(std::make_optional(tr::TorchView::of(logits))); }) .def("embedding_bias", [](tb::LlmRequest& self) { std::optional value{std::nullopt}; auto tensor = self.getEmbeddingBias(); if (tensor) { value = tr::Torch::tensor(*tensor); } return value; }) .def("lora_config", [](tb::LlmRequest& self) { std::optional value{std::nullopt}; auto tensor = self.getLoraConfig(); if (tensor) { value = tr::Torch::tensor(*tensor); } return value; }) .def("lora_weights", [](tb::LlmRequest& self) { std::optional value{std::nullopt}; auto tensor = self.getLoraWeights(); if (tensor) { value = tr::Torch::tensor(*tensor); } return value; }) .def("stop_words_list", [](tb::LlmRequest& self) { std::optional value{std::nullopt}; auto tensor = self.getStopWordsList(); if (tensor) { value = tr::Torch::tensor(*tensor); } return value; }) .def_property_readonly("prompt_vocab_size", &tb::LlmRequest::getPromptVocabSize) .def_property_readonly("lora_task_id", &tb::LlmRequest::getLoraTaskId) .def_property_readonly("lookahead_config", &tb::LlmRequest::getLookaheadConfig) .def_property_readonly( "context_current_position", py::overload_cast<>(&tb::LlmRequest::getContextCurrentPosition, py::const_)) .def_property("context_chunk_size", &tb::LlmRequest::getContextChunkSize, &tb::LlmRequest::setContextChunkSize) .def_readwrite("request_id", &tb::LlmRequest::mRequestId) .def_readwrite("prompt_len", &tb::LlmRequest::mPromptLen) .def_readwrite("max_new_tokens", &tb::LlmRequest::mMaxNewTokens) .def_readwrite("sampling_config", &tb::LlmRequest::mSamplingConfig) .def_readwrite("state", &tb::LlmRequest::mState) .def_readwrite("is_streaming", &tb::LlmRequest::mIsStreaming) .def_readwrite("end_id", &tb::LlmRequest::mEndId) .def_readwrite("pad_id", &tb::LlmRequest::mPadId) .def_readwrite("seq_slot", &tb::LlmRequest::mSeqSlot) .def_property_readonly("return_log_probs", &tb::LlmRequest::returnLogProbs) .def_property_readonly("return_context_logits", &tb::LlmRequest::setReturnContextLogits) .def_property_readonly("return_generation_logits", &tb::LlmRequest::setReturnGenerationLogits) .def_property_readonly("log_probs", py::overload_cast<>(&tb::LlmRequest::getLogProbs, py::const_)) .def("get_log_probs", py::overload_cast(&tb::LlmRequest::getLogProbs, py::const_)) .def("set_log_probs", &tb::LlmRequest::setLogProbs, py::arg("log_probs"), py::arg("beam")) .def("set_return_encoder_output", &tb::LlmRequest::setReturnEncoderOutput, py::arg("return_encoder_output")) .def("get_return_encoder_output", &tb::LlmRequest::getReturnEncoderOutput) .def("priority", py::overload_cast<>(&tb::LlmRequest::priority, py::const_)) .def("set_priority", py::overload_cast(&tb::LlmRequest::setPriority)) .def_property_readonly("cum_log_probs", &tb::LlmRequest::getCumLogProbs) .def("set_cum_log_prob", &tb::LlmRequest::setCumLogProb, py::arg("cum_log_prob"), py::arg("beam")) .def_property_readonly("orig_prompt_len", &tb::LlmRequest::getOrigPromptLen) .def("has_draft_tokens", &tb::LlmRequest::hasDraftTokens) .def("move_to_next_context_chunk", &tb::LlmRequest::moveToNextContextChunk) .def("is_last_context_chunk", py::overload_cast<>(&tb::LlmRequest::isLastContextChunk, py::const_)) .def("is_first_context_chunk", py::overload_cast<>(&tb::LlmRequest::isFirstContextChunk, py::const_)) .def( "get_context_remaining_length", py::overload_cast<>(&tb::LlmRequest::getContextRemainingLength, py::const_)) .def_property( "draft_tokens", [](tb::LlmRequest& self) { return *self.getDraftTokens(); }, [](tb::LlmRequest& self, tb::LlmRequest::VecTokens& draftTokens) { self.setDraftTokens(std::make_shared(std::move(draftTokens))); }); py::bind_vector(m, "RequestVector"); }