/*
 * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "tensorrt_llm/batch_manager/llmRequest.h"

#include <ATen/ATen.h>
#include <memory>
#include <optional>
#include <pybind11/pybind11.h>
#include <vector>

namespace tensorrt_llm::pybind::batch_manager
{

namespace tb = tensorrt_llm::batch_manager;

/* Unfortunately, torch's default pybind bindings don't know about c10::cuda::CUDAStream,
 * so we have to pass the more generic c10::Stream and convert it back to a full-fledged
 * torch.cuda.Stream in Python. See the example in test/bindings/test_gpt_manager.py.
 */
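/* This wrapper accepts token containers by value from Python and converts them in the
 * constructor's initializer list into the shared_ptr form tb::GenericLlmRequest expects:
 * missing draft tokens become an empty VecTokens, while the other optional containers
 * are forwarded as std::nullopt.
 */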
class LlmRequest : public tb::GenericLlmRequest<at::Tensor, c10::Stream>
{
public:
    using Base = GenericLlmRequest<at::Tensor, c10::Stream>;
    using TensorPtr = Base::TensorPtr;
    using SizeType32 = Base::SizeType32;
    using TokenIdType = Base::TokenIdType;
    using RequestIdType = Base::RequestIdType;
    using LoraTaskIdType = Base::LoraTaskIdType;
    using VecLogProbs = Base::VecLogProbs;
    using BeamTokens = Base::BeamTokens;
    using VecTokens = Base::VecTokens;
    using VecTokenExtraIds = Base::VecTokenExtraIds;
    using LogitsPostProcessor = Base::LogitsPostProcessor;

    LlmRequest(RequestIdType requestId, SizeType32 maxNewTokens, std::vector<TokenIdType> inputTokens,
        runtime::SamplingConfig samplingConfig, bool isStreaming, std::optional<SizeType32> endId = std::nullopt,
        std::optional<SizeType32> padId = std::nullopt, std::optional<TensorPtr> embeddingBias = std::nullopt,
        std::optional<TensorPtr> badWordsList = std::nullopt, std::optional<TensorPtr> stopWordsList = std::nullopt,
        std::optional<std::vector<SizeType32>> positionIds = std::nullopt,
        std::optional<TensorPtr> promptEmbeddingTable = std::nullopt,
        std::optional<SizeType32> promptVocabSize = std::nullopt,
        std::optional<LoraTaskIdType> loraTaskId = std::nullopt, std::optional<TensorPtr> loraWeights = std::nullopt,
        std::optional<TensorPtr> loraConfig = std::nullopt,
        std::optional<executor::LookaheadDecodingConfig> lookaheadConfig = std::nullopt, bool returnLogProbs = false,
        bool returnContextLogits = false, bool returnGenerationLogits = false,
        std::optional<VecTokens> draftTokens = std::nullopt, std::optional<TensorPtr> draftLogits = std::nullopt,
        bool excludeInputFromOutput = false, std::optional<LogitsPostProcessor> logitsPostProcessor = std::nullopt,
        bool applyLogitsPostProcessorBatched = false, std::optional<VecTokens> encoderInputTokens = std::nullopt,
        bool returnEncoderOutput = false, std::optional<RequestIdType> clientId = std::nullopt,
        executor::PriorityType priority = executor::Request::kDefaultPriority,
        std::optional<TensorPtr> encoderInputFeatures = std::nullopt,
        std::optional<SizeType32> encoderOutputLength = std::nullopt,
        std::optional<VecTokenExtraIds> inputTokenExtraIds = std::nullopt, SizeType32 numReturnSequences = 1)
        : Base(requestId, maxNewTokens, std::make_shared<std::vector<TokenIdType>>(std::move(inputTokens)),
            samplingConfig, isStreaming, endId, padId, embeddingBias, badWordsList, stopWordsList,
            positionIds.has_value() ? std::make_shared<std::vector<SizeType32>>(std::move(positionIds.value()))
                                    : std::optional<std::shared_ptr<std::vector<SizeType32>>>(std::nullopt),
            promptEmbeddingTable, promptVocabSize, loraTaskId, loraWeights, loraConfig, lookaheadConfig,
            returnLogProbs, returnContextLogits, returnGenerationLogits,
            draftTokens.has_value() ? std::make_shared<VecTokens>(std::move(draftTokens.value()))
                                    : std::make_shared<VecTokens>(),
            draftLogits, excludeInputFromOutput, logitsPostProcessor, applyLogitsPostProcessorBatched,
            encoderInputTokens ? std::make_optional(std::make_shared<VecTokens>(std::move(*encoderInputTokens)))
                               : std::optional<std::shared_ptr<VecTokens>>(std::nullopt),
            returnEncoderOutput, clientId, priority, encoderInputFeatures, encoderOutputLength,
            tb::LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION,
            inputTokenExtraIds
                ? std::make_optional(std::make_shared<VecTokenExtraIds>(std::move(*inputTokenExtraIds)))
                : std::optional<std::shared_ptr<VecTokenExtraIds>>(std::nullopt),
            numReturnSequences)
    {
    }

    static std::optional<tb::LlmRequest::LogitsPostProcessor> callbackAdapter(
        std::optional<LlmRequest::LogitsPostProcessor> callback);

    [[nodiscard]] std::shared_ptr<tb::LlmRequest> toTrtLlm() const;

    static void initBindings(pybind11::module_& m);
};

} // namespace tensorrt_llm::pybind::batch_manager
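/* Example registration (a sketch; the module and submodule names are illustrative, not
 * prescribed by this header). From the pybind11 entry point, one would create a submodule
 * and call initBindings on it:
 *
 *     PYBIND11_MODULE(bindings, m)
 *     {
 *         auto mBatchManager = m.def_submodule("batch_manager");
 *         tensorrt_llm::pybind::batch_manager::LlmRequest::initBindings(mBatchManager);
 *     }
 */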