/*
 * SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "tensorrt_llm/batch_manager/llmRequest.h"
#include "tensorrt_llm/common/assert.h"

// NOTE(review): the four directives below carry no header name in this copy of the file.
// It looks like everything between angle brackets was stripped (the same applies to the
// template arguments further down) -- restore the header names before building.
#include
#include
#include
#include

namespace tensorrt_llm::pybind::batch_manager
{

/// Request type declared in the `pybind` namespace, built on top of
/// tensorrt_llm::batch_manager::GenericLlmRequest.
///
/// The class adds no data members of its own in this header: it re-exports the base
/// class's type aliases, provides one constructor that adapts by-value arguments to what
/// the base constructor is given, and declares toTrtLlm().
///
/// NOTE(review): `GenericLlmRequest` is presumably a class template whose argument list is
/// missing from this view -- confirm against the declaration in llmRequest.h.
class LlmRequest : public tensorrt_llm::batch_manager::GenericLlmRequest
{
public:
    // Shorthand for the base class; every alias below is taken from it unchanged.
    using Base = GenericLlmRequest;
    using TensorPtr = Base::TensorPtr;
    using SizeType = Base::SizeType;
    using TokenIdType = Base::TokenIdType;
    using RequestIdType = Base::RequestIdType;
    using VecLogProbs = Base::VecLogProbs;
    using BeamTokens = Base::BeamTokens;
    using VecTokens = Base::VecTokens;

    /// Builds a request by forwarding every argument to the Base constructor, in the same
    /// order in which the parameters are declared here.
    ///
    /// Two arguments are adapted before being forwarded:
    ///  - `inputTokens` is taken by value and moved into a freshly allocated shared object
    ///    (std::make_shared), so Base receives a shared pointer rather than the vector.
    ///  - `draftTokens` is moved into a shared object when it holds a value; otherwise a
    ///    default-constructed shared object is created. Either way Base is handed the
    ///    result of std::make_shared rather than an optional.
    ///
    /// All remaining arguments (`samplingConfig`, the optionals, the flags) are passed
    /// through as-is, without std::move.
    ///
    /// Every `std::optional` parameter defaults to std::nullopt and `returnLogProbs`
    /// defaults to false, so only the first five arguments are mandatory.
    ///
    /// NOTE(review): the element types of `std::vector` / `std::optional` and the type
    /// arguments of `std::make_shared` are not visible in this copy of the file (note the
    /// dangling `>` after the first `std::make_shared`) -- restore them from upstream.
    LlmRequest(RequestIdType requestId, SizeType maxNewTokens, std::vector inputTokens,
        runtime::SamplingConfig samplingConfig, bool isStreaming, std::optional endId = std::nullopt,
        std::optional padId = std::nullopt, std::optional embeddingBias = std::nullopt,
        std::optional badWordsList = std::nullopt, std::optional stopWordsList = std::nullopt,
        std::optional promptEmbeddingTable = std::nullopt, std::optional promptVocabSize = std::nullopt,
        bool returnLogProbs = false, std::optional draftTokens = std::nullopt,
        std::optional draftLogits = std::nullopt)
        : Base(requestId, maxNewTokens,
            // Transfer ownership of the by-value token vector into shared storage.
            std::make_shared>(std::move(inputTokens)), samplingConfig, isStreaming, endId, padId,
            embeddingBias, badWordsList, stopWordsList, promptEmbeddingTable, promptVocabSize, returnLogProbs,
            // Absent draft tokens become a default-constructed shared object.
            draftTokens.has_value() ? std::make_shared(std::move(draftTokens.value())) : std::make_shared(),
            draftLogits)
    {
    }

    /// Declared here and defined elsewhere; returns a shared pointer and does not modify
    /// this object (const). The result must not be discarded ([[nodiscard]]).
    ///
    /// NOTE(review): judging by the name, this presumably converts the request into the
    /// core TensorRT-LLM request type -- the pointee type is missing from this view, so
    /// confirm against the definition.
    [[nodiscard]] std::shared_ptr toTrtLlm() const;
};

} // namespace tensorrt_llm::pybind::batch_manager