/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/executor/executor.h"

#include <cstdint>
#include <optional>
#include <utility>
#include <vector>

namespace tensorrt_llm::executor
{

// Builds an ExecutorConfig from the individual settings.
//
// By-value configuration objects are moved into their members; parameters taken by const reference are copied.
// After all members are initialised the following conditions are checked with TLLM_CHECK:
//   - iterStatsMaxIterations >= 0
//   - requestStatsMaxIterations >= 0
//   - maxBeamWidth > 0
//   - maxSeqIdleMicroseconds > 0
// Note that, unlike the corresponding setters, maxBatchSize and maxNumTokens are not range-checked here.
ExecutorConfig::ExecutorConfig(SizeType32 maxBeamWidth, SchedulerConfig schedulerConfig, KvCacheConfig kvCacheConfig,
    bool enableChunkedContext, bool normalizeLogProbs, SizeType32 iterStatsMaxIterations,
    SizeType32 requestStatsMaxIterations, BatchingType batchingType, std::optional<SizeType32> maxBatchSize,
    std::optional<SizeType32> maxNumTokens, std::optional<ParallelConfig> parallelConfig,
    std::optional<PeftCacheConfig> const& peftCacheConfig,
    std::optional<LogitsPostProcessorConfig> logitsPostProcessorConfig, std::optional<DecodingConfig> decodingConfig,
    bool useGpuDirectStorage, float gpuWeightPercent, std::optional<SizeType32> maxQueueSize,
    ExtendedRuntimePerfKnobConfig const& extendedRuntimePerfKnobConfig, std::optional<DebugConfig> debugConfig,
    SizeType32 recvPollPeriodMs, uint64_t maxSeqIdleMicroseconds,
    std::optional<SpeculativeDecodingConfig> specDecConfig, std::optional<GuidedDecodingConfig> guidedDecodingConfig,
    std::optional<std::vector<AdditionalModelOutput>> additionalModelOutputs,
    std::optional<CacheTransceiverConfig> cacheTransceiverConfig, bool gatherGenerationLogits,
    bool promptTableOffloading, bool enableTrtOverlap, bool failFastOnAttentionWindowTooLarge)
    : mMaxBeamWidth(maxBeamWidth)
    , mSchedulerConfig(std::move(schedulerConfig))
    , mKvCacheConfig(std::move(kvCacheConfig))
    , mEnableChunkedContext(enableChunkedContext)
    , mNormalizeLogProbs(normalizeLogProbs)
    , mIterStatsMaxIterations(iterStatsMaxIterations)
    , mRequestStatsMaxIterations(requestStatsMaxIterations)
    , mBatchingType(batchingType)
    , mMaxBatchSize(maxBatchSize)
    , mMaxNumTokens(maxNumTokens)
    , mParallelConfig(std::move(parallelConfig))
    , mPeftCacheConfig(peftCacheConfig)
    , mLogitsPostProcessorConfig(std::move(logitsPostProcessorConfig))
    , mDecodingConfig(std::move(decodingConfig))
    , mUseGpuDirectStorage(useGpuDirectStorage)
    , mGpuWeightsPercent(gpuWeightPercent)
    , mMaxQueueSize(maxQueueSize)
    , mExtendedRuntimePerfKnobConfig(extendedRuntimePerfKnobConfig)
    , mDebugConfig(std::move(debugConfig))
    , mRecvPollPeriodMs(recvPollPeriodMs)
    , mMaxSeqIdleMicroseconds(maxSeqIdleMicroseconds)
    , mSpeculativeDecodingConfig(std::move(specDecConfig))
    , mGuidedDecodingConfig(std::move(guidedDecodingConfig))
    , mAdditionalModelOutputs(std::move(additionalModelOutputs))
    , mCacheTransceiverConfig(std::move(cacheTransceiverConfig))
    , mGatherGenerationLogits(gatherGenerationLogits)
    , mPromptTableOffloading(promptTableOffloading)
    , mEnableTrtOverlap(enableTrtOverlap)
    , mFailFastOnAttentionWindowTooLarge(failFastOnAttentionWindowTooLarge)
{
    TLLM_CHECK(iterStatsMaxIterations >= 0);
    TLLM_CHECK(requestStatsMaxIterations >= 0);
    TLLM_CHECK(mMaxBeamWidth > 0);
    TLLM_CHECK(maxSeqIdleMicroseconds > 0);
}

// getters
//
// Each getter returns a copy of the corresponding member. The two *Ref() accessors instead return a mutable
// reference to the stored SchedulerConfig / KvCacheConfig so callers can modify them in place.

SizeType32 ExecutorConfig::getMaxBeamWidth() const
{
    return mMaxBeamWidth;
}

SchedulerConfig ExecutorConfig::getSchedulerConfig() const
{
    return mSchedulerConfig;
}

SchedulerConfig& ExecutorConfig::getSchedulerConfigRef()
{
    return mSchedulerConfig;
}

KvCacheConfig ExecutorConfig::getKvCacheConfig() const
{
    return mKvCacheConfig;
}

KvCacheConfig& ExecutorConfig::getKvCacheConfigRef()
{
    return mKvCacheConfig;
}

bool ExecutorConfig::getEnableChunkedContext() const
{
    return mEnableChunkedContext;
}

bool ExecutorConfig::getNormalizeLogProbs() const
{
    return mNormalizeLogProbs;
}

SizeType32 ExecutorConfig::getIterStatsMaxIterations() const
{
    return mIterStatsMaxIterations;
}

SizeType32 ExecutorConfig::getRequestStatsMaxIterations() const
{
    return mRequestStatsMaxIterations;
}

BatchingType ExecutorConfig::getBatchingType() const
{
    return mBatchingType;
}

std::optional<SizeType32> ExecutorConfig::getMaxBatchSize() const
{
    return mMaxBatchSize;
}

std::optional<SizeType32> ExecutorConfig::getMaxNumTokens() const
{
    return mMaxNumTokens;
}

std::optional<ParallelConfig> ExecutorConfig::getParallelConfig() const
{
    return mParallelConfig;
}

std::optional<PeftCacheConfig> ExecutorConfig::getPeftCacheConfig() const
{
    return mPeftCacheConfig;
}

std::optional<LogitsPostProcessorConfig> ExecutorConfig::getLogitsPostProcessorConfig() const
{
    return mLogitsPostProcessorConfig;
}

std::optional<DecodingConfig> ExecutorConfig::getDecodingConfig() const
{
    return mDecodingConfig;
}

bool ExecutorConfig::getUseGpuDirectStorage() const
{
    return mUseGpuDirectStorage;
}

float ExecutorConfig::getGpuWeightsPercent() const
{
    return mGpuWeightsPercent;
}

std::optional<SizeType32> ExecutorConfig::getMaxQueueSize() const
{
    return mMaxQueueSize;
}

ExtendedRuntimePerfKnobConfig ExecutorConfig::getExtendedRuntimePerfKnobConfig() const
{
    return mExtendedRuntimePerfKnobConfig;
}

std::optional<DebugConfig> ExecutorConfig::getDebugConfig() const
{
    return mDebugConfig;
}

SizeType32 ExecutorConfig::getRecvPollPeriodMs() const
{
    return mRecvPollPeriodMs;
}

uint64_t ExecutorConfig::getMaxSeqIdleMicroseconds() const
{
    return mMaxSeqIdleMicroseconds;
}

std::optional<SpeculativeDecodingConfig> ExecutorConfig::getSpecDecConfig() const
{
    return mSpeculativeDecodingConfig;
}

std::optional<GuidedDecodingConfig> ExecutorConfig::getGuidedDecodingConfig() const
{
    return mGuidedDecodingConfig;
}

std::optional<std::vector<AdditionalModelOutput>> ExecutorConfig::getAdditionalModelOutputs() const
{
    return mAdditionalModelOutputs;
}

std::optional<CacheTransceiverConfig> ExecutorConfig::getCacheTransceiverConfig() const
{
    return mCacheTransceiverConfig;
}

bool ExecutorConfig::getGatherGenerationLogits() const
{
    return mGatherGenerationLogits;
}

bool ExecutorConfig::getPromptTableOffloading() const
{
    return mPromptTableOffloading;
}

bool ExecutorConfig::getEnableTrtOverlap() const
{
    return mEnableTrtOverlap;
}

bool ExecutorConfig::getFailFastOnAttentionWindowTooLarge() const
{
    return mFailFastOnAttentionWindowTooLarge;
}

// setters
//
// Setters that carry a range requirement validate the argument BEFORE storing it, so a rejected value never
// overwrites the previously stored one (this matters if a failed TLLM_CHECK is reported by throwing and the
// caller recovers -- NOTE(review): confirm TLLM_CHECK's failure mode against its definition).

// Sets the maximum beam width; the value must be > 0.
void ExecutorConfig::setMaxBeamWidth(SizeType32 maxBeamWidth)
{
    TLLM_CHECK(maxBeamWidth > 0);
    mMaxBeamWidth = maxBeamWidth;
}

// Sets the maximum batch size; the value must be > 0.
void ExecutorConfig::setMaxBatchSize(SizeType32 maxBatchSize)
{
    TLLM_CHECK(maxBatchSize > 0);
    mMaxBatchSize = maxBatchSize;
}

// Sets the maximum number of tokens; the value must be > 0.
void ExecutorConfig::setMaxNumTokens(SizeType32 maxNumTokens)
{
    TLLM_CHECK(maxNumTokens > 0);
    mMaxNumTokens = maxNumTokens;
}

void ExecutorConfig::setSchedulerConfig(SchedulerConfig const& schedulerConfig)
{
    mSchedulerConfig = schedulerConfig;
}

void ExecutorConfig::setKvCacheConfig(KvCacheConfig const& kvCacheConfig)
{
    mKvCacheConfig = kvCacheConfig;
}

void ExecutorConfig::setEnableChunkedContext(bool enableChunkedContext)
{
    mEnableChunkedContext = enableChunkedContext;
}

void ExecutorConfig::setNormalizeLogProbs(bool normalizeLogProbs)
{
    mNormalizeLogProbs = normalizeLogProbs;
}

// Sets the iteration-stats limit; the value must be >= 0.
void ExecutorConfig::setIterStatsMaxIterations(SizeType32 iterStatsMaxIterations)
{
    TLLM_CHECK(iterStatsMaxIterations >= 0);
    mIterStatsMaxIterations = iterStatsMaxIterations;
}

// Sets the request-stats limit; the value must be >= 0.
void ExecutorConfig::setRequestStatsMaxIterations(SizeType32 requestStatsMaxIterations)
{
    TLLM_CHECK(requestStatsMaxIterations >= 0);
    mRequestStatsMaxIterations = requestStatsMaxIterations;
}

void ExecutorConfig::setBatchingType(BatchingType batchingType)
{
    mBatchingType = batchingType;
}

void ExecutorConfig::setParallelConfig(ParallelConfig const& parallelConfig)
{
    mParallelConfig = parallelConfig;
}

void ExecutorConfig::setPeftCacheConfig(PeftCacheConfig const& peftCacheConfig)
{
    mPeftCacheConfig = peftCacheConfig;
}

void ExecutorConfig::setLogitsPostProcessorConfig(LogitsPostProcessorConfig const& logitsPostProcessorConfig)
{
    mLogitsPostProcessorConfig = logitsPostProcessorConfig;
}

void ExecutorConfig::setDecodingConfig(DecodingConfig const& decodingConfig)
{
    mDecodingConfig = decodingConfig;
}

void ExecutorConfig::setUseGpuDirectStorage(bool const& useGpuDirectStorage)
{
    mUseGpuDirectStorage = useGpuDirectStorage;
}

void ExecutorConfig::setGpuWeightsPercent(float const& gpuWeightsPercent)
{
    mGpuWeightsPercent = gpuWeightsPercent;
}

// Stores the optional queue-size limit as given; an empty optional is accepted and no range check is applied.
void ExecutorConfig::setMaxQueueSize(std::optional<SizeType32> const& maxQueueSize)
{
    mMaxQueueSize = maxQueueSize;
}

void ExecutorConfig::setExtendedRuntimePerfKnobConfig(
    ExtendedRuntimePerfKnobConfig const& extendedRuntimePerfKnobConfig)
{
    mExtendedRuntimePerfKnobConfig = extendedRuntimePerfKnobConfig;
}

void ExecutorConfig::setDebugConfig(DebugConfig const& debugConfig)
{
    mDebugConfig = debugConfig;
}

void ExecutorConfig::setRecvPollPeriodMs(SizeType32 const& recvPollPeriodMs)
{
    mRecvPollPeriodMs = recvPollPeriodMs;
}

// Sets the maximum sequence idle time in microseconds; the value must be > 0.
void ExecutorConfig::setMaxSeqIdleMicroseconds(uint64_t maxSeqIdleMicroseconds)
{
    TLLM_CHECK(maxSeqIdleMicroseconds > 0);
    mMaxSeqIdleMicroseconds = maxSeqIdleMicroseconds;
}

void ExecutorConfig::setSpecDecConfig(SpeculativeDecodingConfig const& specDecConfig)
{
    mSpeculativeDecodingConfig = specDecConfig;
}

void ExecutorConfig::setGuidedDecodingConfig(GuidedDecodingConfig const& guidedDecodingConfig)
{
    mGuidedDecodingConfig = guidedDecodingConfig;
}

void ExecutorConfig::setAdditionalModelOutputs(std::vector<AdditionalModelOutput> const& additionalModelOutputs)
{
    mAdditionalModelOutputs = additionalModelOutputs;
}

void ExecutorConfig::setCacheTransceiverConfig(CacheTransceiverConfig const& cacheTransceiverConfig)
{
    mCacheTransceiverConfig = cacheTransceiverConfig;
}

void ExecutorConfig::setGatherGenerationLogits(bool gatherGenerationLogits)
{
    mGatherGenerationLogits = gatherGenerationLogits;
}

void ExecutorConfig::setPromptTableOffloading(bool promptTableOffloading)
{
    mPromptTableOffloading = promptTableOffloading;
}

void ExecutorConfig::setEnableTrtOverlap(bool enableTrtOverlap)
{
    mEnableTrtOverlap = enableTrtOverlap;
}

void ExecutorConfig::setFailFastOnAttentionWindowTooLarge(bool failFastOnAttentionWindowTooLarge)
{
    mFailFastOnAttentionWindowTooLarge = failFastOnAttentionWindowTooLarge;
}

} // namespace tensorrt_llm::executor