/*
 * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "tensorrt_llm/batch_manager/llmRequest.h"

#include <ATen/ATen.h>
#include <memory>
#include <nanobind/nanobind.h>
#include <optional>
#include <vector>

namespace nb = nanobind;

namespace tensorrt_llm::nanobind::batch_manager
{

namespace tb = tensorrt_llm::batch_manager;

/* Unfortunately, torch's default nanobind bindings don't know about c10::cuda::CUDAStream,
 * so we have to pass the more generic c10::Stream, and convert it back to a full-fledged
 * torch.cuda.Stream in python. See example in test/bindings/test_gpt_manager.py */
class LlmRequest : public tb::GenericLlmRequest<at::Tensor, c10::Stream>
{
public:
    using Base = GenericLlmRequest<at::Tensor, c10::Stream>;
    using TensorPtr = Base::TensorPtr;
    using SizeType32 = Base::SizeType32;
    using TokenIdType = Base::TokenIdType;
    using RequestIdType = Base::RequestIdType;
    using LoraTaskIdType = Base::LoraTaskIdType;
    using VecLogProbs = Base::VecLogProbs;
    using BeamTokens = Base::BeamTokens;
    using VecTokens = Base::VecTokens;
    using VecTokenExtraIds = Base::VecTokenExtraIds;
    using LogitsPostProcessor = Base::LogitsPostProcessor;

    LlmRequest(RequestIdType requestId, SizeType32 maxNewTokens, std::vector<TokenIdType> inputTokens,
        runtime::SamplingConfig samplingConfig, bool isStreaming, std::optional<SizeType32> endId = std::nullopt,
        std::optional<SizeType32> padId = std::nullopt, std::optional<TensorPtr> embeddingBias = std::nullopt,
        std::optional<TensorPtr> badWordsList = std::nullopt, std::optional<TensorPtr> stopWordsList = std::nullopt,
        std::optional<std::vector<SizeType32>> positionIds = std::nullopt,
        std::optional<TensorPtr> promptEmbeddingTable = std::nullopt,
        std::optional<SizeType32> promptVocabSize = std::nullopt,
        std::optional<std::vector<std::vector<SizeType32>>> multimodalHashes = std::nullopt,
        std::optional<std::vector<SizeType32>> multimodalPositions = std::nullopt,
        std::optional<std::vector<SizeType32>> multimodalLengths = std::nullopt,
        std::optional<TensorPtr> multimodalEmbedding = std::nullopt,
        std::optional<TensorPtr> mropeRotaryCosSin = std::nullopt,
        std::optional<SizeType32> mropePositionDeltas = std::nullopt,
        std::optional<LoraTaskIdType> loraTaskId = std::nullopt, std::optional<TensorPtr> loraWeights = std::nullopt,
        std::optional<TensorPtr> loraConfig = std::nullopt,
        std::optional<executor::LookaheadDecodingConfig> lookaheadConfig = std::nullopt,
        std::optional<executor::KvCacheRetentionConfig> kvCacheRetentionConfig = std::nullopt,
        bool returnLogProbs = false, bool returnContextLogits = false, bool returnGenerationLogits = false,
        std::optional<VecTokens> draftTokens = std::nullopt, std::optional<TensorPtr> draftLogits = std::nullopt,
        bool excludeInputFromOutput = false, std::optional<LogitsPostProcessor> logitsPostProcessor = std::nullopt,
        bool applyLogitsPostProcessorBatched = false, std::optional<VecTokens> encoderInputTokens = std::nullopt,
        bool returnEncoderOutput = false, std::optional<RequestIdType> clientId = std::nullopt,
        executor::PriorityType priority = executor::Request::kDefaultPriority,
        std::optional<TensorPtr> encoderInputFeatures = std::nullopt,
        std::optional<SizeType32> encoderOutputLength = std::nullopt,
        std::optional<TensorPtr> crossAttentionMask = std::nullopt,
        tb::LlmRequestType llmRequestType = tb::LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION,
        std::optional<VecTokenExtraIds> inputTokenExtraIds = std::nullopt, SizeType32 numReturnSequences = 1,
        std::optional<executor::EagleConfig>
            eagleConfig = std::nullopt,
        std::optional<TensorPtr> skipCrossAttnBlocks = std::nullopt, bool returnPerfMetrics = false,
        std::optional<executor::GuidedDecodingParams> guidedDecodingParams = std::nullopt,
        std::optional<SizeType32> languageAdapterUid = std::nullopt,
        std::optional<MillisecondsType> allottedTimeMs = std::nullopt,
        std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt,
        std::optional<CacheSaltIDType> cacheSaltID = std::nullopt,
        std::optional<double> arrivalTime = std::nullopt)
        : Base(requestId, //
            maxNewTokens, //
            std::make_shared<std::vector<TokenIdType>>(std::move(inputTokens)), //
            samplingConfig, //
            isStreaming, //
            endId, //
            padId, //
            embeddingBias, //
            badWordsList, //
            stopWordsList, //
            positionIds.has_value() ? std::make_shared<std::vector<SizeType32>>(std::move(positionIds.value())) //
                                    : std::optional<std::shared_ptr<std::vector<SizeType32>>>(std::nullopt), //
            promptEmbeddingTable, //
            promptVocabSize, //
            multimodalHashes.has_value() ? std::make_optional(std::make_shared<std::vector<std::vector<SizeType32>>>(
                                               std::move(multimodalHashes.value()))) //
                                         : std::optional<std::shared_ptr<std::vector<std::vector<SizeType32>>>>(
                                               std::nullopt), //
            multimodalPositions.has_value()
                ? std::make_shared<std::vector<SizeType32>>(std::move(multimodalPositions.value())) //
                : std::optional<std::shared_ptr<std::vector<SizeType32>>>(std::nullopt), //
            multimodalLengths.has_value()
                ? std::make_shared<std::vector<SizeType32>>(std::move(multimodalLengths.value())) //
                : std::optional<std::shared_ptr<std::vector<SizeType32>>>(std::nullopt), //
            multimodalEmbedding, //
            mropeRotaryCosSin, //
            mropePositionDeltas, //
            loraTaskId, //
            loraWeights, //
            loraConfig, //
            lookaheadConfig, //
            kvCacheRetentionConfig, //
            returnLogProbs, //
            returnContextLogits, //
            returnGenerationLogits, //
            draftTokens.has_value() ? std::make_shared<VecTokens>(std::move(draftTokens.value())) //
                                    : std::make_shared<VecTokens>(), //
            draftLogits, //
            excludeInputFromOutput, //
            logitsPostProcessor, //
            applyLogitsPostProcessorBatched, //
            encoderInputTokens ? std::make_optional(std::make_shared<VecTokens>(std::move(*encoderInputTokens))) //
                               : std::optional<std::shared_ptr<VecTokens>>(std::nullopt), //
            returnEncoderOutput, //
            clientId, //
            priority, //
            encoderInputFeatures, //
            encoderOutputLength, //
            crossAttentionMask, //
            llmRequestType, //
            inputTokenExtraIds //
                ? std::make_optional(std::make_shared<VecTokenExtraIds>(std::move(*inputTokenExtraIds))) //
                : std::optional<std::shared_ptr<VecTokenExtraIds>>(std::nullopt), //
            numReturnSequences, //
            eagleConfig, //
            skipCrossAttnBlocks, //
            returnPerfMetrics, //
            guidedDecodingParams, //
            languageAdapterUid, //
            allottedTimeMs, //
            contextPhaseParams, //
            cacheSaltID, //
            arrivalTime //
            )
    {
    }

    /// Adapts an optional Python-level logits post-processor into the optional
    /// callback type expected by the batch manager's LlmRequest.
    static std::optional<tb::LlmRequest::LogitsPostProcessor> callbackAdapter(
        std::optional<LogitsPostProcessor> callback);

    /// Builds the equivalent batch-manager tb::LlmRequest from this binding-side request.
    [[nodiscard]] std::shared_ptr<tb::LlmRequest> toTrtLlm() const;
};

} // namespace tensorrt_llm::nanobind::batch_manager
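
/* Illustrative only: a minimal sketch of how an adapter class like this could be
 * registered with nanobind, assuming the surrounding TensorRT-LLM build environment.
 * The module name "llm_request_example" and the trimmed argument list are
 * hypothetical; the actual registration lives in the corresponding bindings .cpp
 * and binds the full constructor with named, defaulted arguments.
 *
 *   namespace tbn = tensorrt_llm::nanobind::batch_manager;
 *
 *   NB_MODULE(llm_request_example, m)
 *   {
 *       nb::class_<tbn::LlmRequest>(m, "LlmRequest")
 *           // Only the five required leading parameters are listed; the remaining
 *           // parameters are defaulted in C++, so nb::init<> still selects this
 *           // constructor.
 *           .def(nb::init<tbn::LlmRequest::RequestIdType, tbn::LlmRequest::SizeType32,
 *                    std::vector<tbn::LlmRequest::TokenIdType>,
 *                    tensorrt_llm::runtime::SamplingConfig, bool>(),
 *               nb::arg("request_id"), nb::arg("max_new_tokens"), nb::arg("input_tokens"),
 *               nb::arg("sampling_config"), nb::arg("is_streaming"));
 *   }
 */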