/*
 * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "llmRequest.h"
|
|
|
|
#include "tensorrt_llm/batch_manager/llmRequest.h"
|
|
#include "tensorrt_llm/pybind/utils/bindTypes.h"
|
|
#include "tensorrt_llm/runtime/torch.h"
|
|
#include "tensorrt_llm/runtime/torchUtils.h"
|
|
#include "tensorrt_llm/runtime/torchView.h"
|
|
|
|
#include <ATen/ATen.h>
|
|
#include <pybind11/functional.h>
|
|
#include <pybind11/operators.h>
|
|
#include <pybind11/stl.h>
|
|
#include <pybind11/stl_bind.h>
|
|
#include <torch/extension.h>
|
|
|
|
#include <memory>
|
|
|
|

namespace tb = tensorrt_llm::batch_manager;
namespace tr = tensorrt_llm::runtime;
namespace tle = tensorrt_llm::executor;

using namespace tensorrt_llm::pybind::batch_manager;

using LlmRequestPtr = std::shared_ptr<tb::LlmRequest>;
using RequestList = std::list<LlmRequestPtr>;

namespace
{
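
// Helper used by LlmRequest::toTrtLlm() below: if the optional holds a
// PyTorch-facing tensor, wrap it as a runtime ITensor view via
// tr::TorchView::of; otherwise propagate the empty optional.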
std::optional<tb::LlmRequest::TensorPtr> from_torch(std::optional<LlmRequest::TensorPtr> torchPtr)
{
    if (torchPtr)
    {
        return tr::TorchView::of(torchPtr.value());
    }
    return std::nullopt;
}

} // namespace
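
// Adapts an optional Python-side logits post-processor to the signature the
// batch manager expects: the returned lambda converts the runtime ITensor to
// an at::Tensor and the CUDA stream handle to its Torch representation before
// invoking the user callback.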
std::optional<tb::LlmRequest::LogitsPostProcessor> LlmRequest::callbackAdapter(
    std::optional<LlmRequest::LogitsPostProcessor> callback)
{
    if (!callback)
    {
        return std::nullopt;
    }

    return [callback](RequestIdType reqId, tensorrt_llm::runtime::ITensor::SharedPtr& tensor,
               tensorrt_llm::batch_manager::LlmRequest::BeamTokens const& tokens,
               tensorrt_llm::runtime::BufferManager::CudaStreamPtr stream, std::optional<RequestIdType> clientId)
    {
        at::Tensor atTensor = tr::Torch::tensor(tensor);
        callback.value()(reqId, atTensor, tokens, runtime::TorchUtils::stream(*stream).unwrap(), clientId);
    };
}
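
// Builds the batch manager's tb::LlmRequest from this Python-facing request,
// converting every optional PyTorch tensor member through from_torch and
// forwarding the remaining fields to the tb::LlmRequest constructor.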
std::shared_ptr<tb::LlmRequest> LlmRequest::toTrtLlm() const
{
    auto embeddingBias = from_torch(mEmbeddingBias);
    auto badWordsList = from_torch(mBadWordsList);
    auto stopWordsList = from_torch(mStopWordsList);
    auto promptEmbeddingTable = from_torch(mPromptEmbeddingTable);
    auto loraWeights = from_torch(mLoraWeights);
    auto loraConfig = from_torch(mLoraConfig);
    auto draftLogits = from_torch(mDraftLogits);
    auto encoderInputFeatures = from_torch(mEncoderInputFeatures);

    return std::make_shared<tb::LlmRequest>(mRequestId, mMaxNewTokens,
        std::make_shared<std::vector<TokenIdType>>(mTokens.at(0)), mSamplingConfig, mIsStreaming, mEndId, mPadId,
        embeddingBias, badWordsList, stopWordsList, mPositionIds, promptEmbeddingTable, mPromptVocabSize, mLoraTaskId,
        loraWeights, loraConfig, mLookaheadConfig, returnLogProbs(), mReturnContextLogits, mReturnGenerationLogits,
        mDraftTokens, draftLogits, mExcludeInputFromOutput, callbackAdapter(mLogitsPostProcessor),
        mApplyLogitsPostProcessorBatched, mEncoderTokens, mReturnEncoderOutput, mClientId, mPriority,
        encoderInputFeatures, mEncoderOutputLength, tb::LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION,
        mInputTokenExtraIds, mNumReturnSequences);
}
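
// Registers the Python-facing "LlmRequest" class: the constructor below is
// exposed with the keyword arguments listed in the py::arg(...) chain, and the
// remaining .def/.def_property entries mirror the C++ accessors.
//
// Illustrative Python usage (a sketch only; the import path of the bindings
// module and the construction of a SamplingConfig are assumptions here):
//
//   req = LlmRequest(request_id=0, max_new_tokens=16,
//                    input_tokens=[1, 2, 3],
//                    sampling_config=sampling_config,
//                    is_streaming=False, end_id=2)
//   req.add_new_token(42, beam=0)
//   print(req.get_tokens(0))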
void LlmRequest::initBindings(py::module_& m)
{
    py::class_<LlmRequest>(m, "LlmRequest")
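        // Constructor binding: the template arguments enumerate the C++
        // constructor parameters in order, and the py::arg(...) list below
        // supplies their Python keyword names and defaults one-to-one.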
        .def(py::init<LlmRequest::RequestIdType, LlmRequest::SizeType32, LlmRequest::VecTokens, tr::SamplingConfig,
                 bool, std::optional<LlmRequest::SizeType32>, std::optional<LlmRequest::SizeType32>,
                 std::optional<LlmRequest::TensorPtr>, std::optional<LlmRequest::TensorPtr>,
                 std::optional<LlmRequest::TensorPtr>, std::optional<std::vector<LlmRequest::SizeType32>>,
                 std::optional<LlmRequest::TensorPtr>, std::optional<LlmRequest::SizeType32>, std::optional<uint64_t>,
                 std::optional<LlmRequest::TensorPtr>, std::optional<LlmRequest::TensorPtr>,
                 std::optional<executor::LookaheadDecodingConfig>, bool, bool, bool,
                 std::optional<LlmRequest::VecTokens>, std::optional<LlmRequest::TensorPtr>, bool,
                 std::optional<LlmRequest::LogitsPostProcessor>, bool, std::optional<LlmRequest::VecTokens>, bool,
                 std::optional<RequestIdType>, executor::PriorityType, std::optional<LlmRequest::TensorPtr>,
                 std::optional<LlmRequest::SizeType32>, std::optional<LlmRequest::VecTokenExtraIds>,
                 LlmRequest::SizeType32>(),
            py::arg("request_id"), py::arg("max_new_tokens"), py::arg("input_tokens"), py::arg("sampling_config"),
            py::arg("is_streaming"), py::arg("end_id") = std::nullopt, py::arg("pad_id") = std::nullopt,
            py::arg("embedding_bias") = std::nullopt, py::arg("bad_words_list") = std::nullopt,
            py::arg("stop_words_list") = std::nullopt, py::arg("position_ids") = std::nullopt,
            py::arg("prompt_embedding_table") = std::nullopt, py::arg("prompt_vocab_size") = std::nullopt,
            py::arg("lora_task_id") = std::nullopt, py::arg("lora_weights") = std::nullopt,
            py::arg("lora_config") = std::nullopt, py::arg("lookahead_config") = std::nullopt,
            py::arg("return_log_probs") = false, py::arg("return_context_logits") = false,
            py::arg("return_generation_logits") = false, py::arg("draft_tokens") = std::nullopt,
            py::arg("draft_logits") = std::nullopt, py::arg("exclude_input_from_output") = false,
            py::arg("logits_post_processor") = std::nullopt, py::arg("apply_logits_post_processor_batched") = false,
            py::arg("encoder_input_tokens") = std::nullopt, py::arg("return_encoder_output") = false,
            py::arg("client_id") = std::nullopt, py::arg("priority") = executor::Request::kDefaultPriority,
            py::arg("encoder_input_features") = std::nullopt, py::arg("encoder_output_length") = std::nullopt,
            py::arg("input_token_extra_ids") = std::nullopt, py::arg("num_return_sequences") = 1)
        .def("get_num_tokens", &LlmRequest::getNumTokens, py::arg("beam"))
        .def_property_readonly("max_beam_num_tokens", &LlmRequest::getMaxBeamNumTokens)
        .def("get_token", &LlmRequest::getToken, py::arg("beam"), py::arg("pos"))
        .def("get_tokens", py::overload_cast<LlmRequest::SizeType32>(&LlmRequest::getTokens, py::const_),
            py::arg("beam"))
        .def("get_tokens", py::overload_cast<>(&LlmRequest::getTokens, py::const_))
        .def_property_readonly("max_num_generated_tokens", &LlmRequest::getMaxNumGeneratedTokens)
        .def("add_new_token", &LlmRequest::addNewToken, py::arg("token"), py::arg("beam"))
        .def("add_new_tokens", &LlmRequest::addNewTokens, py::arg("beam_tokens"))
        .def("set_generated_tokens", &LlmRequest::setGeneratedTokens, py::arg("generated_beam_tokens"))
        .def("pause", &LlmRequest::pause, py::arg("max_input_len"))
        .def_property("max_sent_token_len", &LlmRequest::getMaxSentTokenLen, &LlmRequest::setMaxSentTokenLen)
        .def_property_readonly("prompt_embedding_table", &LlmRequest::getPromptEmbeddingTable)
        .def_property_readonly("prompt_vocab_size", &LlmRequest::getPromptVocabSize)
        .def_property_readonly("lora_task_id", &LlmRequest::getLoraTaskId)
        .def_property_readonly("lora_weights", &LlmRequest::getLoraWeights)
        .def_property_readonly("lora_config", &LlmRequest::getLoraConfig)
        .def_property_readonly("lookahead_config", &LlmRequest::getLookaheadConfig)
        .def_property_readonly("embedding_bias", &LlmRequest::getEmbeddingBias)
        .def_property_readonly("bad_words_list", &LlmRequest::getBadWordsList)
        .def_property_readonly("stop_words_list", &LlmRequest::getStopWordsList)
        .def_property_readonly("position_ids",
            [](LlmRequest& self)
            { return *self.getPositionIds().value_or(std::make_shared<std::vector<SizeType32>>()); })
        .def_property_readonly(
            "context_current_position", py::overload_cast<>(&LlmRequest::getContextCurrentPosition, py::const_))
        .def_property("context_chunk_size", &LlmRequest::getContextChunkSize, &LlmRequest::setContextChunkSize)
        .def_readwrite("request_id", &LlmRequest::mRequestId)
        .def_readwrite("prompt_len", &LlmRequest::mPromptLen)
        .def_readwrite("max_new_tokens", &LlmRequest::mMaxNewTokens)
        .def_readwrite("sampling_config", &LlmRequest::mSamplingConfig)
        .def_readwrite("state", &LlmRequest::mState)
        .def_readwrite("is_streaming", &LlmRequest::mIsStreaming)
        .def_readwrite("end_id", &LlmRequest::mEndId)
        .def_readwrite("pad_id", &LlmRequest::mPadId)
        .def_readwrite("seq_slot", &LlmRequest::mSeqSlot)
        .def_property_readonly("return_log_probs", &LlmRequest::returnLogProbs)
.def_property_readonly("return_context_logits", &LlmRequest::setReturnContextLogits)
|
|
.def_property_readonly("return_generation_logits", &LlmRequest::setReturnGenerationLogits)
|
|
.def_property_readonly("log_probs", py::overload_cast<>(&LlmRequest::getLogProbs, py::const_))
|
|
.def("get_log_probs", py::overload_cast<SizeType32>(&LlmRequest::getLogProbs, py::const_))
|
|
.def("set_log_probs", &LlmRequest::setLogProbs, py::arg("log_probs"), py::arg("beam"))
|
|
.def("set_return_encoder_output", &LlmRequest::setReturnEncoderOutput, py::arg("return_encoder_output"))
|
|
.def("get_return_encoder_output", &LlmRequest::getReturnEncoderOutput)
|
|
.def("priority", py::overload_cast<>(&LlmRequest::priority, py::const_))
|
|
.def("set_priority", py::overload_cast<executor::PriorityType>(&LlmRequest::setPriority))
|
|
.def_property_readonly("cum_log_probs", &LlmRequest::getCumLogProbs)
|
|
.def("set_cum_log_prob", &LlmRequest::setCumLogProb, py::arg("cum_log_prob"), py::arg("beam"))
|
|
.def_property_readonly("orig_prompt_len", &LlmRequest::getOrigPromptLen)
|
|
.def("has_draft_tokens", &LlmRequest::hasDraftTokens)
|
|
.def("move_to_next_context_chunk", &LlmRequest::moveToNextContextChunk)
|
|
.def("is_last_context_chunk", py::overload_cast<>(&LlmRequest::isLastContextChunk, py::const_))
|
|
.def("is_first_context_chunk", py::overload_cast<>(&LlmRequest::isFirstContextChunk, py::const_))
|
|
.def("get_context_remaining_length", py::overload_cast<>(&LlmRequest::getContextRemainingLength, py::const_))
|
|
.def_property(
|
|
"draft_tokens", [](LlmRequest& self) { return *self.getDraftTokens(); },
|
|
[](LlmRequest& self, LlmRequest::VecTokens& draftTokens)
|
|
{ self.setDraftTokens(std::make_shared<LlmRequest::VecTokens>(std::move(draftTokens))); })
|
|
.def_property(
|
|
"draft_logits", [](LlmRequest& self) { return self.getDraftLogits(); },
|
|
[](LlmRequest& self, LlmRequest::TensorPtr& logits)
|
|
{ self.setDraftLogits(std::make_optional<LlmRequest::TensorPtr>(logits)); })
|
|
.def_property("num_return_sequences", &LlmRequest::getNumReturnSequences, &LlmRequest::setNumReturnSequences);
|
|
}
|
|
|
|
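
// Registers the internal batch manager request type as "PyLlmRequest". Unlike
// the class above, its tensor-valued accessors are exposed as methods that
// return torch tensors (see the lambdas below) rather than raw ITensor handles.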
void tb::LlmRequestBindings::initBindings(py::module_& m)
{
    py::classh<tb::LlmRequest>(m, "PyLlmRequest")
        .def("get_num_tokens", &tb::LlmRequest::getNumTokens, py::arg("beam"))
        .def_property_readonly("max_beam_num_tokens", &tb::LlmRequest::getMaxBeamNumTokens)
        .def("get_token", &tb::LlmRequest::getToken, py::arg("beam"), py::arg("pos"))
        .def("get_tokens", py::overload_cast<tb::LlmRequest::SizeType32>(&tb::LlmRequest::getTokens, py::const_),
            py::arg("beam"))
        .def("get_tokens", py::overload_cast<>(&tb::LlmRequest::getTokens, py::const_))
        .def_property_readonly("max_num_generated_tokens", &tb::LlmRequest::getMaxNumGeneratedTokens)
        .def("add_new_token", &tb::LlmRequest::addNewToken, py::arg("token"), py::arg("beam"))
        .def("add_new_tokens", &tb::LlmRequest::addNewTokens, py::arg("beam_tokens"))
        .def("set_generated_tokens", &tb::LlmRequest::setGeneratedTokens, py::arg("generated_beam_tokens"))
        .def("pause", &tb::LlmRequest::pause, py::arg("max_input_len"))
        .def_property("max_sent_token_len", &tb::LlmRequest::getMaxSentTokenLen, &tb::LlmRequest::setMaxSentTokenLen)
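        // The following accessors share one pattern: fetch the optional
        // runtime ITensor from the request and, when present, convert it to an
        // at::Tensor with tr::Torch::tensor so Python receives a torch tensor
        // (or None when the field is unset).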
.def("prompt_embedding_table",
|
|
[](tb::LlmRequest& self)
|
|
{
|
|
std::optional<at::Tensor> value{std::nullopt};
|
|
auto tensor = self.getPromptEmbeddingTable();
|
|
if (tensor)
|
|
{
|
|
value = tr::Torch::tensor(*tensor);
|
|
}
|
|
return value;
|
|
})
|
|
.def("bad_words_list",
|
|
[](tb::LlmRequest& self)
|
|
{
|
|
std::optional<at::Tensor> value{std::nullopt};
|
|
auto tensor = self.getBadWordsList();
|
|
if (tensor)
|
|
{
|
|
value = tr::Torch::tensor(*tensor);
|
|
}
|
|
return value;
|
|
})
|
|
.def_property(
|
|
"draft_logits",
|
|
[](tb::LlmRequest& self)
|
|
{
|
|
std::optional<at::Tensor> value{std::nullopt};
|
|
auto tensor = self.getDraftLogits();
|
|
if (tensor)
|
|
{
|
|
value = tr::Torch::tensor(*tensor);
|
|
}
|
|
return value;
|
|
},
|
|
[](tb::LlmRequest& self, at::Tensor& logits)
|
|
{ self.setDraftLogits(std::make_optional<tb::LlmRequest::TensorPtr>(tr::TorchView::of(logits))); })
|
|
.def("embedding_bias",
|
|
[](tb::LlmRequest& self)
|
|
{
|
|
std::optional<at::Tensor> value{std::nullopt};
|
|
auto tensor = self.getEmbeddingBias();
|
|
if (tensor)
|
|
{
|
|
value = tr::Torch::tensor(*tensor);
|
|
}
|
|
return value;
|
|
})
|
|
.def("lora_config",
|
|
[](tb::LlmRequest& self)
|
|
{
|
|
std::optional<at::Tensor> value{std::nullopt};
|
|
auto tensor = self.getLoraConfig();
|
|
if (tensor)
|
|
{
|
|
value = tr::Torch::tensor(*tensor);
|
|
}
|
|
return value;
|
|
})
|
|
.def("lora_weights",
|
|
[](tb::LlmRequest& self)
|
|
{
|
|
std::optional<at::Tensor> value{std::nullopt};
|
|
auto tensor = self.getLoraWeights();
|
|
if (tensor)
|
|
{
|
|
value = tr::Torch::tensor(*tensor);
|
|
}
|
|
return value;
|
|
})
|
|
.def("stop_words_list",
|
|
[](tb::LlmRequest& self)
|
|
{
|
|
std::optional<at::Tensor> value{std::nullopt};
|
|
auto tensor = self.getStopWordsList();
|
|
if (tensor)
|
|
{
|
|
value = tr::Torch::tensor(*tensor);
|
|
}
|
|
return value;
|
|
})
|
|
.def_property_readonly("prompt_vocab_size", &tb::LlmRequest::getPromptVocabSize)
|
|
.def_property_readonly("lora_task_id", &tb::LlmRequest::getLoraTaskId)
|
|
.def_property_readonly("lookahead_config", &tb::LlmRequest::getLookaheadConfig)
|
|
.def_property_readonly(
|
|
"context_current_position", py::overload_cast<>(&tb::LlmRequest::getContextCurrentPosition, py::const_))
|
|
.def_property("context_chunk_size", &tb::LlmRequest::getContextChunkSize, &tb::LlmRequest::setContextChunkSize)
|
|
.def_readwrite("request_id", &tb::LlmRequest::mRequestId)
|
|
.def_readwrite("prompt_len", &tb::LlmRequest::mPromptLen)
|
|
.def_readwrite("max_new_tokens", &tb::LlmRequest::mMaxNewTokens)
|
|
.def_readwrite("sampling_config", &tb::LlmRequest::mSamplingConfig)
|
|
.def_readwrite("state", &tb::LlmRequest::mState)
|
|
.def_readwrite("is_streaming", &tb::LlmRequest::mIsStreaming)
|
|
.def_readwrite("end_id", &tb::LlmRequest::mEndId)
|
|
.def_readwrite("pad_id", &tb::LlmRequest::mPadId)
|
|
.def_readwrite("seq_slot", &tb::LlmRequest::mSeqSlot)
|
|
.def_property_readonly("return_log_probs", &tb::LlmRequest::returnLogProbs)
|
|
.def_property_readonly("return_context_logits", &tb::LlmRequest::setReturnContextLogits)
|
|
.def_property_readonly("return_generation_logits", &tb::LlmRequest::setReturnGenerationLogits)
|
|
.def_property_readonly("log_probs", py::overload_cast<>(&tb::LlmRequest::getLogProbs, py::const_))
|
|
.def("get_log_probs", py::overload_cast<tb::LlmRequest::SizeType32>(&tb::LlmRequest::getLogProbs, py::const_))
|
|
.def("set_log_probs", &tb::LlmRequest::setLogProbs, py::arg("log_probs"), py::arg("beam"))
|
|
.def("set_return_encoder_output", &tb::LlmRequest::setReturnEncoderOutput, py::arg("return_encoder_output"))
|
|
.def("get_return_encoder_output", &tb::LlmRequest::getReturnEncoderOutput)
|
|
.def("priority", py::overload_cast<>(&tb::LlmRequest::priority, py::const_))
|
|
.def("set_priority", py::overload_cast<tle::PriorityType>(&tb::LlmRequest::setPriority))
|
|
.def_property_readonly("cum_log_probs", &tb::LlmRequest::getCumLogProbs)
|
|
.def("set_cum_log_prob", &tb::LlmRequest::setCumLogProb, py::arg("cum_log_prob"), py::arg("beam"))
|
|
.def_property_readonly("orig_prompt_len", &tb::LlmRequest::getOrigPromptLen)
|
|
.def("has_draft_tokens", &tb::LlmRequest::hasDraftTokens)
|
|
.def("move_to_next_context_chunk", &tb::LlmRequest::moveToNextContextChunk)
|
|
.def("is_last_context_chunk", py::overload_cast<>(&tb::LlmRequest::isLastContextChunk, py::const_))
|
|
.def("is_first_context_chunk", py::overload_cast<>(&tb::LlmRequest::isFirstContextChunk, py::const_))
|
|
.def(
|
|
"get_context_remaining_length", py::overload_cast<>(&tb::LlmRequest::getContextRemainingLength, py::const_))
|
|
.def_property(
|
|
"draft_tokens", [](tb::LlmRequest& self) { return *self.getDraftTokens(); },
|
|
[](tb::LlmRequest& self, tb::LlmRequest::VecTokens& draftTokens)
|
|
{ self.setDraftTokens(std::make_shared<tb::LlmRequest::VecTokens>(std::move(draftTokens))); });
|
|
|
|
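
    // Expose tb::RequestVector (a std::vector of LlmRequest pointers; see the
    // batch manager headers for the exact element type) as an opaque,
    // list-like "RequestVector" container on the Python side.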
    py::bind_vector<tb::RequestVector>(m, "RequestVector");
}