/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// NOTE(review): the directive below has no header name in the text as received. It looks like an
// angle-bracket include whose `<...>` part was lost -- restore it from version control.
#include
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/executor/requestImpl.h"
#include "tensorrt_llm/executor/tensor.h"
#include "tensorrt_llm/executor/types.h"

// NOTE(review): throughout this file `std::optional`, `std::vector`, `std::list` and `std::make_unique`
// appear without template argument lists, and some declarations carry a stray `>`. This looks like
// `<...>` content that was lost before review. The code tokens below are left exactly as received;
// the real argument types must be restored from the class declaration (presumably in executor.h -- confirm).

namespace tensorrt_llm::executor
{

/// Construct a Request by forwarding every argument to the implementation object stored in `mImpl`.
///
/// Arguments wrapped in `std::move` below are moved into the implementation; all others are passed as
/// plain lvalues.
///
/// NOTE(review): `lookaheadConfig`, `crossAttentionMask`, `eagleConfig` and `skipCrossAttnBlocks` are
/// taken by value but forwarded without `std::move`, unlike most of their by-value neighbours --
/// confirm whether that is intentional.
/// NOTE(review): the parameter name `logitslogitsPostProcessor` (doubled prefix) looks like a typo for
/// `logitsPostProcessor`. It is used consistently within this definition, so it is harmless here.
Request::Request(VecTokens inputTokenIds, SizeType32 maxTokens, bool streaming, SamplingConfig const& samplingConfig,
    OutputConfig const& outputConfig, std::optional const& endId, std::optional const& padId,
    std::optional> positionIds, std::optional> badWords, std::optional> stopWords,
    std::optional embeddingBias, std::optional externalDraftTokensConfig, std::optional pTuningConfig,
    std::optional multimodalInput, std::optional multimodalEmbedding, std::optional mRopeConfig,
    std::optional loraConfig, std::optional lookaheadConfig, std::optional kvCacheRetentionConfig,
    std::optional logitsPostProcessorName, std::optional logitslogitsPostProcessor,
    std::optional encoderInputTokenIds, std::optional clientId, bool returnAllGeneratedTokens, float priority,
    RequestType type, std::optional contextPhaseParams, std::optional encoderInputFeatures,
    std::optional encoderOutputLength, std::optional crossAttentionMask, SizeType32 numReturnSequences,
    std::optional eagleConfig, std::optional skipCrossAttnBlocks, std::optional guidedDecodingParams,
    std::optional languageAdapterUid, std::optional allottedTimeMs, std::optional cacheSaltID)
    : mImpl(std::make_unique(std::move(inputTokenIds), maxTokens, streaming, samplingConfig, outputConfig, endId,
        padId, std::move(positionIds), std::move(badWords), std::move(stopWords), std::move(embeddingBias),
        std::move(externalDraftTokensConfig), std::move(pTuningConfig), std::move(multimodalInput),
        std::move(multimodalEmbedding), std::move(mRopeConfig), std::move(loraConfig), lookaheadConfig,
        std::move(kvCacheRetentionConfig), std::move(logitsPostProcessorName), std::move(logitslogitsPostProcessor),
        std::move(encoderInputTokenIds), clientId, returnAllGeneratedTokens, priority, type,
        std::move(contextPhaseParams), std::move(encoderInputFeatures), encoderOutputLength, crossAttentionMask,
        numReturnSequences, eagleConfig, skipCrossAttnBlocks, std::move(guidedDecodingParams), languageAdapterUid,
        allottedTimeMs, cacheSaltID))
{
}

/// Destructor, defaulted in this translation unit.
Request::~Request() = default;

/// Copy constructor: builds a new implementation object from `*other.mImpl`.
///
/// NOTE(review): `other.mImpl` is dereferenced unconditionally. If `mImpl` is a `std::unique_ptr`
/// (it is initialised from `std::make_unique` above -- confirm in the header), a moved-from Request
/// holds a null pointer and copying from it would dereference null.
Request::Request(Request const& other)
    : mImpl(std::make_unique(*other.mImpl))
{
}

/// Move constructor, defaulted (member-wise move of `mImpl`).
Request::Request(Request&& other) noexcept = default;

/// Copy assignment: replaces `mImpl` with a new implementation object built from `*other.mImpl`.
/// Self-assignment is detected and leaves the object untouched.
///
/// NOTE(review): same unconditional `*other.mImpl` dereference as the copy constructor.
Request& Request::operator=(Request const& other)
{
    if (this != &other)
    {
        mImpl = std::make_unique(*other.mImpl);
    }
    return *this;
}

/// Move assignment, defaulted (member-wise move of `mImpl`).
Request& Request::operator=(Request&& other) noexcept = default;

// ---------------------------------------------------------------------------------------------------
// Getters. Each one returns whatever the same-named accessor on the implementation object returns,
// except where noted.
// ---------------------------------------------------------------------------------------------------

VecTokens Request::getInputTokenIds() const
{
    return mImpl->getInputTokenIds();
}

// Note the name difference: the public getter is getMaxTokens(), the implementation accessor it
// forwards to is getMaxNewTokens().
SizeType32 Request::getMaxTokens() const
{
    return mImpl->getMaxNewTokens();
}

bool Request::getStreaming() const
{
    return mImpl->getStreaming();
}

SamplingConfig Request::getSamplingConfig() const
{
    return mImpl->getSamplingConfig();
}

OutputConfig Request::getOutputConfig() const
{
    return mImpl->getOutputConfig();
}

std::optional Request::getEndId() const
{
    return mImpl->getEndId();
}

std::optional Request::getPadId() const
{
    return mImpl->getPadId();
}

std::optional> Request::getPositionIds() const
{
    return mImpl->getPositionIds();
}

std::optional> Request::getBadWords() const
{
    return mImpl->getBadWords();
}
// ---------------------------------------------------------------------------------------------------
// Getters. Each one returns whatever the same-named accessor on the implementation object (`mImpl`)
// returns.
//
// NOTE(review): the `std::optional` return types below carry no template arguments (and the first one
// has a stray `>`) in the text as received. This looks like lost `<...>` content; restore the real
// types from the class declaration. Code tokens are left exactly as received.
// ---------------------------------------------------------------------------------------------------

std::optional> Request::getStopWords() const
{
    return mImpl->getStopWords();
}

std::optional Request::getEmbeddingBias() const
{
    return mImpl->getEmbeddingBias();
}

std::optional Request::getExternalDraftTokensConfig() const
{
    return mImpl->getExternalDraftTokensConfig();
}

std::optional Request::getPromptTuningConfig() const
{
    return mImpl->getPromptTuningConfig();
}

std::optional Request::getMultimodalEmbedding() const
{
    return mImpl->getMultimodalEmbedding();
}

std::optional Request::getMultimodalInput() const
{
    return mImpl->getMultimodalInput();
}

std::optional Request::getMropeConfig() const
{
    return mImpl->getMropeConfig();
}

std::optional Request::getLoraConfig() const
{
    return mImpl->getLoraConfig();
}

std::optional Request::getLookaheadConfig() const
{
    return mImpl->getLookaheadConfig();
}

std::optional Request::getKvCacheRetentionConfig() const
{
    return mImpl->getKvCacheRetentionConfig();
}

std::optional Request::getLogitsPostProcessorName() const
{
    return mImpl->getLogitsPostProcessorName();
}

std::optional Request::getLogitsPostProcessor() const
{
    return mImpl->getLogitsPostProcessor();
}

std::optional Request::getEncoderInputTokenIds() const
{
    return mImpl->getEncoderInputTokenIds();
}

std::optional Request::getClientId() const
{
    return mImpl->getClientId();
}

PriorityType Request::getPriority() const
{
    return mImpl->getPriority();
}

std::optional Request::getAllottedTimeMs() const
{
    return mImpl->getAllottedTimeMs();
}

bool Request::getReturnAllGeneratedTokens() const
{
    return mImpl->getReturnAllGeneratedTokens();
}

RequestType Request::getRequestType() const
{
    return mImpl->getRequestType();
}

// Unlike the other getters in this file, this one returns by const reference rather than by value,
// so the result refers to whatever the implementation's accessor returns.
std::optional const& Request::getContextPhaseParams() const
{
    return mImpl->getContextPhaseParams();
}

std::optional Request::getEncoderInputFeatures() const
{
    return mImpl->getEncoderInputFeatures();
}

std::optional Request::getEncoderOutputLength() const
{
    return mImpl->getEncoderOutputLength();
}

std::optional Request::getCrossAttentionMask() const
{
    return mImpl->getCrossAttentionMask();
}

std::optional Request::getEagleConfig() const
{
    return mImpl->getEagleConfig();
}

std::optional Request::getSkipCrossAttnBlocks() const
{
    return mImpl->getSkipCrossAttnBlocks();
}

std::optional Request::getGuidedDecodingParams() const
{
    return mImpl->getGuidedDecodingParams();
}

std::optional Request::getLanguageAdapterUid() const
{
    return mImpl->getLanguageAdapterUid();
}

std::optional Request::getCacheSaltID() const
{
    return mImpl->getCacheSaltID();
}

// ---------------------------------------------------------------------------------------------------
// Setters. Each one passes its argument to the same-named setter on the implementation object.
// No validation is performed in this file; any checking would live in the implementation (not visible
// here).
//
// NOTE(review): `std::vector`, `std::list` and `std::optional` parameter types below carry no template
// arguments in the text as received -- see the note at the top of the getters.
// NOTE(review): some setters are written as `return mImpl->setX(...)` in a `void` function while most
// omit the `return`; both forms forward the call, the difference is only stylistic.
// ---------------------------------------------------------------------------------------------------

void Request::setStreaming(bool streaming)
{
    mImpl->setStreaming(streaming);
}

void Request::setSamplingConfig(SamplingConfig const& config)
{
    mImpl->setSamplingConfig(config);
}

void Request::setOutputConfig(OutputConfig const& outputConfig)
{
    mImpl->setOutputConfig(outputConfig);
}

void Request::setEndId(SizeType32 endId)
{
    mImpl->setEndId(endId);
}

void Request::setPadId(SizeType32 padId)
{
    mImpl->setPadId(padId);
}

void Request::setPositionIds(std::vector const& positionIds)
{
    mImpl->setPositionIds(positionIds);
}

void Request::setBadWords(std::list const& badWords)
{
    mImpl->setBadWords(badWords);
}

void Request::setStopWords(std::list const& stopWords)
{
    mImpl->setStopWords(stopWords);
}

void Request::setEmbeddingBias(Tensor const& embeddingBias)
{
    mImpl->setEmbeddingBias(embeddingBias);
}

// The parameter is named `specDecodingConfig` here although the setter is for the external draft
// tokens config; it is forwarded unchanged.
void Request::setExternalDraftTokensConfig(ExternalDraftTokensConfig const& specDecodingConfig)
{
    mImpl->setExternalDraftTokensConfig(specDecodingConfig);
}

void Request::setPromptTuningConfig(PromptTuningConfig const& pTuningConfig)
{
    mImpl->setPromptTuningConfig(pTuningConfig);
}

void Request::setMultimodalEmbedding(Tensor const& multimodalEmbedding)
{
    return mImpl->setMultimodalEmbedding(multimodalEmbedding);
}

void Request::setMultimodalInput(MultimodalInput const& multimodalInput)
{
    return mImpl->setMultimodalInput(multimodalInput);
}

void Request::setMropeConfig(MropeConfig const& mRopeConfig)
{
    mImpl->setMropeConfig(mRopeConfig);
}

void Request::setLoraConfig(LoraConfig const& loraConfig)
{
    mImpl->setLoraConfig(loraConfig);
}

void Request::setLookaheadConfig(LookaheadDecodingConfig const& lookaheadConfig)
{
    mImpl->setLookaheadConfig(lookaheadConfig);
}

void Request::setKvCacheRetentionConfig(KvCacheRetentionConfig const& kvCacheRetentionConfig)
{
    mImpl->setKvCacheRetentionConfig(kvCacheRetentionConfig);
}

void Request::setLogitsPostProcessorName(std::string const& logitsPostProcessorName)
{
    mImpl->setLogitsPostProcessorName(logitsPostProcessorName);
}

// Takes an optional, so callers can also pass an empty optional; what the implementation does with an
// empty value is not visible in this file.
void Request::setLogitsPostProcessor(std::optional const& logitsPostProcessor)
{
    mImpl->setLogitsPostProcessor(logitsPostProcessor);
}

void Request::setEncoderInputTokenIds(VecTokens const& encoderInputTokenIds)
{
    mImpl->setEncoderInputTokenIds(encoderInputTokenIds);
}

void Request::setClientId(IdType clientId)
{
    mImpl->setClientId(clientId);
}

void Request::setPriority(PriorityType priority)
{
    mImpl->setPriority(priority);
}

void Request::setReturnAllGeneratedTokens(bool returnAllGeneratedTokens)
{
    mImpl->setReturnAllGeneratedTokens(returnAllGeneratedTokens);
}

void Request::setRequestType(RequestType const& requestType)
{
    mImpl->setRequestType(requestType);
}

// Taken by value and moved into the implementation.
void Request::setContextPhaseParams(ContextPhaseParams contextPhaseParams)
{
    mImpl->setContextPhaseParams(std::move(contextPhaseParams));
}

// NOTE(review): taken by value but forwarded as an lvalue (no std::move), unlike
// setContextPhaseParams above -- confirm whether that is intentional.
void Request::setEncoderInputFeatures(Tensor encoderInputFeatures)
{
    mImpl->setEncoderInputFeatures(encoderInputFeatures);
}

void Request::setEncoderOutputLength(SizeType32 encoderOutputLength)
{
    mImpl->setEncoderOutputLength(encoderOutputLength);
}

// NOTE(review): by-value parameter forwarded without std::move (see setEncoderInputFeatures).
void Request::setCrossAttentionMask(Tensor crossAttentionMask)
{
    mImpl->setCrossAttentionMask(crossAttentionMask);
}

// Takes an optional, so callers can also pass an empty optional; how the implementation treats an
// empty value is not visible in this file.
void Request::setEagleConfig(std::optional const& eagleConfig)
{
    mImpl->setEagleConfig(eagleConfig);
}

// NOTE(review): by-value parameter forwarded without std::move (see setEncoderInputFeatures).
void Request::setSkipCrossAttnBlocks(Tensor skipCrossAttnBlocks)
{
    return mImpl->setSkipCrossAttnBlocks(skipCrossAttnBlocks);
}

void Request::setGuidedDecodingParams(GuidedDecodingParams const& guidedDecodingParams)
{
    mImpl->setGuidedDecodingParams(guidedDecodingParams);
}

void Request::setAllottedTimeMs(MillisecondsType allottedTimeMs)
{
    return mImpl->setAllottedTimeMs(allottedTimeMs);
}

void Request::setLanguageAdapterUid(SizeType32 languageAdapterUid)
{
    return mImpl->setLanguageAdapterUid(languageAdapterUid);
}

void Request::setCacheSaltID(CacheSaltIDType cacheSaltID)
{
    return mImpl->setCacheSaltID(cacheSaltID);
}

} // namespace tensorrt_llm::executor