/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/batch_manager/capacityScheduler.h"
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
#include "tensorrt_llm/batch_manager/peftCacheManager.h"
#include "tensorrt_llm/batch_manager/scheduledBlocksManager.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/common/nvtxUtils.h"

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <memory>
#include <optional>
#include <stdexcept>
#include <tuple>
#include <type_traits>
#include <unordered_set>
#include <utility>

namespace tensorrt_llm::batch_manager
{

using kv_cache_manager::VecUniqueTokens;
using kv_cache_manager::BlockKey;
using kv_cache_manager::BlockKeyHasher;

namespace
{

std::tuple<std::unordered_set<BlockKey, BlockKeyHasher>, std::unordered_set<BlockKey, BlockKeyHasher>>
prefillWithChunkedContextsAlreadyExecuting(RequestList const& activeRequests,
    kv_cache_manager::BaseKVCacheManager const& kvCacheManager,
    OptionalRef<kv_cache_manager::BaseKVCacheManager const> crossKvCacheManager = std::nullopt)
{
    std::unordered_set<BlockKey, BlockKeyHasher> newlyContributedContextBlocks;
    std::unordered_set<BlockKey, BlockKeyHasher> newlyContributedCrossContextBlocks;
    for (auto const& req : activeRequests)
    {
        if (req->isContextInitState() && !req->isFirstContextChunk())
        {
            // Chunked context request already executing, but it hasn't completed all chunks yet.
            // Skipping is not an option, so register its contributed blocks.
            if (kvCacheManager.isEnableBlockReuse())
            {
                auto uniqueTokens = req->getUniqueTokens(0);
                auto newContextBlockOpt = kvCacheManager.findNewContextBlock(uniqueTokens, *req);
                if (newContextBlockOpt.has_value())
                {
                    newlyContributedContextBlocks.insert(newContextBlockOpt.value());
                }
            }
            if (crossKvCacheManager && crossKvCacheManager->isEnableBlockReuse())
            {
                auto uniqueTokens = *(req->getEncoderUniqueTokens().value());
                auto newContextBlockOpt = crossKvCacheManager->findNewContextBlock(uniqueTokens, *req);
                if (newContextBlockOpt.has_value())
                {
                    newlyContributedCrossContextBlocks.insert(newContextBlockOpt.value());
                }
            }
        }
    }
    return {std::move(newlyContributedContextBlocks), std::move(newlyContributedCrossContextBlocks)};
}

bool oneManagerBeneficialToSkip(
    tensorrt_llm::batch_manager::kv_cache_manager::BaseKVCacheManager const& kvCacheManager,
    VecUniqueTokens const& uniqueTokens, std::shared_ptr<LlmRequest> const& llmRequest,
    std::unordered_set<BlockKey, BlockKeyHasher>& newlyContributedContextBlocks)
{
    // Find first context block that isn't already in KV cache
    auto newContextBlockOpt = kvCacheManager.findNewContextBlock(uniqueTokens, *llmRequest);
    if (newContextBlockOpt.has_value())
    {
        auto const& newContextBlock = newContextBlockOpt.value();
        if (newlyContributedContextBlocks.count(newContextBlock) > 0)
        {
            // newContextBlock was contributed by an earlier scheduled request.
            // Better to skip this request so we can reuse that block.
            return true;
        }
        // This request is contributing newContextBlock.
        newlyContributedContextBlocks.insert(newContextBlock);
    }
    // Either all context blocks are already in KV cache,
    // or no previously scheduled request has contributed newContextBlock.
    return false;
}
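// Example of the reuse heuristic implemented above: if requests A and B share a prompt prefix that maps to the
// same BlockKey and A has already been picked in this scheduling pass, A is recorded as the contributor of that
// block. When B is examined, findNewContextBlock() yields the same key, so it is better to defer B for now and
// let it reuse A's block in a later pass instead of recomputing the shared prefix.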
//! \brief Check if it is beneficial to skip this request rather than schedule it.
//! \details One condition that makes it beneficial is if this request can reuse kv cache block(s) contributed by
//! already scheduled context requests.
bool beneficialToSkip(std::shared_ptr<LlmRequest> const& req,
    kv_cache_manager::BaseKVCacheManager const& kvCacheManager,
    OptionalRef<kv_cache_manager::BaseKVCacheManager const> crossKvCacheManager,
    std::unordered_set<BlockKey, BlockKeyHasher>& newlyContributedContextBlocks,
    std::unordered_set<BlockKey, BlockKeyHasher>& newlyContributedCrossContextBlocks)
{
    if (req->isContextInitState() && req->isFirstContextChunk())
    {
        if (kvCacheManager.isEnableBlockReuse())
        {
            auto uniqueTokens = req->getUniqueTokens(0);
            if (oneManagerBeneficialToSkip(kvCacheManager, uniqueTokens, req, newlyContributedContextBlocks))
            {
                return true;
            }
        }
        if (crossKvCacheManager && crossKvCacheManager->isEnableBlockReuse())
        {
            auto uniqueTokens = *(req->getEncoderUniqueTokens().value());
            if (oneManagerBeneficialToSkip(
                    *crossKvCacheManager, uniqueTokens, req, newlyContributedCrossContextBlocks))
            {
                return true;
            }
        }
    }
    return false;
}

} // namespace

MaxRequestsScheduler::MaxRequestsScheduler(
    SizeType32 maxNumRequests, LlmRequestState noScheduleUntilState, LlmRequestState noScheduleAfterState)
    : BaseCapacityScheduler(noScheduleUntilState, noScheduleAfterState)
    , mMaxNumRequests(maxNumRequests)
{
}

MaxUtilizationScheduler::MaxUtilizationScheduler(SizeType32 maxNumRequests, bool twoStepsLookAhead,
    LlmRequestState noScheduleUntilState, LlmRequestState noScheduleAfterState)
    : BaseCapacityScheduler(noScheduleUntilState, noScheduleAfterState)
    , mMaxNumRequests(maxNumRequests)
    , mTwoStepsLookAhead{twoStepsLookAhead}
{
}

GuaranteedNoEvictScheduler::GuaranteedNoEvictScheduler(
    SizeType32 maxNumRequests, LlmRequestState noScheduleUntilState, LlmRequestState noScheduleAfterState)
    : BaseCapacityScheduler(noScheduleUntilState, noScheduleAfterState)
    , mMaxNumRequests(maxNumRequests)
{
}

StaticBatchScheduler::StaticBatchScheduler(
    SizeType32 maxNumRequests, LlmRequestState noScheduleUntilState, LlmRequestState noScheduleAfterState)
    : GuaranteedNoEvictScheduler(maxNumRequests, noScheduleUntilState, noScheduleAfterState)
{
}

std::tuple<RequestVector, RequestVector> MaxRequestsScheduler::operator()(RequestList const& activeRequests) const
{
    RequestVector scheduledRequests;
    for (auto const& req : activeRequests)
    {
        // if request cannot be scheduled yet or request should no longer be scheduled, skip
        if (!req->hasReachedState(getNoScheduleUntilState()) || req->hasReachedState(getNoScheduleAfterState()))
        {
            continue;
        }

        if (scheduledRequests.size() >= static_cast<std::size_t>(mMaxNumRequests))
        {
            break;
        }
        if (req->isEncoderInitState() || req->isContextInitState() || req->isGenerationInProgressState())
        {
            scheduledRequests.emplace_back(req);
        }
    }
    return {std::move(scheduledRequests), RequestVector{}};
}

std::tuple<RequestVector, RequestVector> StaticBatchScheduler::operator()(
    kv_cache_manager::BaseKVCacheManager const& kvCacheManager,
    OptionalRef<kv_cache_manager::BaseKVCacheManager const> crossKvCacheManager,
    OptionalRef<BasePeftCacheManager const> peftCacheManager, RequestList const& activeRequests) const
{
    return this->impl<true>(kvCacheManager, crossKvCacheManager, peftCacheManager, activeRequests);
}

std::tuple<RequestVector, RequestVector> GuaranteedNoEvictScheduler::operator()(
    kv_cache_manager::BaseKVCacheManager const& kvCacheManager,
    OptionalRef<kv_cache_manager::BaseKVCacheManager const> crossKvCacheManager,
    OptionalRef<BasePeftCacheManager const> peftCacheManager, RequestList const& activeRequests) const
{
    return impl<false>(kvCacheManager, crossKvCacheManager, peftCacheManager, activeRequests);
}
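// StaticBatchScheduler and GuaranteedNoEvictScheduler share the templated implementation below; the
// StaticBatchScheduling parameter is the only behavioral difference. With StaticBatchScheduling == true, pending
// requests are admitted only once the previous batch has fully drained (nothing is generation-in-progress), while
// with StaticBatchScheduling == false new requests may join the in-flight batch as long as their KV cache blocks
// and PEFT pages are guaranteed to remain available until completion.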
template <bool StaticBatchScheduling>
std::tuple<RequestVector, RequestVector> GuaranteedNoEvictScheduler::impl(
    kv_cache_manager::BaseKVCacheManager const& kvCacheManager,
    OptionalRef<kv_cache_manager::BaseKVCacheManager const> crossKvCacheManager,
    OptionalRef<BasePeftCacheManager const> peftCacheManager, RequestList const& activeRequests) const
{
    RequestVector scheduledRequests;

    // Determine the PEFT page budget available on device
    auto const maxPeftCachePages
        = peftCacheManager ? peftCacheManager->getMaxDevicePages() : std::numeric_limits<SizeType32>::max();

    // The optimization of delaying requests won't work for variable window attention
    bool skippingIsRelevant = (!kvCacheManager.getBlockManager().isVariableWindow())
        && (!crossKvCacheManager || !crossKvCacheManager->getBlockManager().isVariableWindow());

    // Keep track of blocks contributed by requests in context phase
    std::unordered_set<BlockKey, BlockKeyHasher> newlyContributedContextBlocks;
    std::unordered_set<BlockKey, BlockKeyHasher> newlyContributedCrossContextBlocks;
    if constexpr (!StaticBatchScheduling)
    {
        if (skippingIsRelevant)
        {
            std::tie(newlyContributedContextBlocks, newlyContributedCrossContextBlocks)
                = prefillWithChunkedContextsAlreadyExecuting(activeRequests, kvCacheManager, crossKvCacheManager);
        }
    }

    // If a request is already in progress, include it:
    // if it has been allocated, it had the resources to run to completion.
    // Also keep track of blocks needed to drive all in-progress requests to completion.
    auto reservedBlocks = kv_cache_manager::NoEvictScheduledBlocksManager(kvCacheManager);
    auto reservedCrossBlocks = crossKvCacheManager
        ? std::optional(kv_cache_manager::NoEvictScheduledBlocksManager(*crossKvCacheManager))
        : std::nullopt;
    SizeType32 claimedPeftPages{0};
    std::unordered_set<uint64_t> uniqTaskIds{};
    RequestVector pendingRequests;
    RequestVector pendingDisGenInitRequests;
    pendingRequests.reserve(activeRequests.size());
    pendingDisGenInitRequests.reserve(activeRequests.size());
    for (auto const& req : activeRequests)
    {
        // if request cannot be scheduled yet or request should no longer be scheduled, skip
        if ( // Allow disagg_generation_init requests to be scheduled, so that we'll allocate their KV cache
            !req->isDisaggGenerationInitState()
            && (!req->hasReachedState(getNoScheduleUntilState()) || req->hasReachedState(getNoScheduleAfterState())))
        {
            continue;
        }

        if (scheduledRequests.size() >= static_cast<std::size_t>(mMaxNumRequests))
        {
            break;
        }
        else if (req->isGenerationInProgressState())
        {
            scheduledRequests.emplace_back(req);
            reservedBlocks.decrementReservedBlocks(*req);
            if (reservedCrossBlocks)
                reservedCrossBlocks->decrementReservedBlocks(*req);
            bool const reqHasLora = req->getLoraTaskId().has_value();
            bool const isNewTask = reqHasLora && !uniqTaskIds.count(req->getLoraTaskId().value());
            if (isNewTask)
            {
                claimedPeftPages += peftCacheManager ? peftCacheManager->determineNumPages(req) : 0;
                uniqTaskIds.insert(req->getLoraTaskId().value());
            }
        }
        else if (req->isDisaggGenerationInitState())
        {
            pendingDisGenInitRequests.emplace_back(req);
        }
        else
        {
            pendingRequests.emplace_back(req);
        }
    }
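    // At this point scheduledRequests contains every generation-in-progress request (each already holds the
    // resources it needs to finish), reservedBlocks/reservedCrossBlocks have been debited for the blocks those
    // requests will still consume, and the two pending lists hold the context-init and
    // disaggregated-generation-init candidates that still need to be admitted.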
    // If StaticBatchScheduling == true, only add pending requests when no requests are already in flight.
    // Otherwise, always check whether pending requests can be added.
    if (!StaticBatchScheduling || scheduledRequests.size() == 0)
    {
        // Now check if we can add pending requests
        auto availablePeftPages = maxPeftCachePages - claimedPeftPages;

        // Loop over pending requests and add them if they can be scheduled.
        // Start by trying to include disagg generation init requests.
        for (auto const& requests : {pendingDisGenInitRequests, pendingRequests})
        {
            for (auto const& req : requests)
            {
                // if context request can reuse blocks contributed by another context request, skip
                if (!StaticBatchScheduling && skippingIsRelevant && !req->isDisaggGenerationInitState()
                    && beneficialToSkip(req, kvCacheManager, crossKvCacheManager, newlyContributedContextBlocks,
                        newlyContributedCrossContextBlocks))
                {
                    continue;
                }

                if (scheduledRequests.size() >= static_cast<std::size_t>(mMaxNumRequests))
                {
                    break;
                }
                else if (req->isContextInitState() || req->isDisaggGenerationInitState())
                {
                    bool enoughBlocks = reservedBlocks.enoughAvailableBlocks(*req);
                    bool enoughCrossBlocks
                        = reservedCrossBlocks ? reservedCrossBlocks->enoughAvailableBlocks(*req) : true;
                    bool reqHasLora = req->getLoraTaskId().has_value();
                    bool isNewTask = reqHasLora && !uniqTaskIds.count(req->getLoraTaskId().value());
                    auto neededPeftPages
                        = isNewTask && peftCacheManager ? peftCacheManager->determineNumPages(req) : 0;
                    if (enoughBlocks && enoughCrossBlocks && neededPeftPages <= availablePeftPages)
                    {
                        scheduledRequests.emplace_back(req);
                        reservedBlocks.decrementReservedBlocks(*req);
                        if (reservedCrossBlocks)
                            reservedCrossBlocks->decrementReservedBlocks(*req);
                        availablePeftPages -= neededPeftPages;
                        if (isNewTask)
                        {
                            uniqTaskIds.insert(req->getLoraTaskId().value());
                        }
                    }
                    else if (!enoughBlocks || !enoughCrossBlocks)
                    {
                        // If one request fails to be scheduled, break
                        break;
                    }
                }
            }
        }
    }
    return {std::move(scheduledRequests), RequestVector{}};
}
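// In contrast to the guaranteed-no-evict policy above, MaxUtilizationScheduler packs as many requests as possible
// into each step and, when KV cache blocks run out, pauses already-started requests (searching backwards from the
// end of the active list) to make room; the paused requests are returned as the second element of the result.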
// TODO(nhaber): remove the forward declaration and keep the function definition here, right before the merge.
// It is placed below only so that the remote diff is easier to review / rebase.
bool trySchedulingRequestMaxUtilization(std::shared_ptr<LlmRequest> const& req, SizeType32 maxNumRequests,
    RequestVector& scheduledRequests, kv_cache_manager::MaxUtilizationScheduledBlocksManager& blocksManager,
    OptionalRef<BasePeftCacheManager const> peftCacheManager, SizeType32& numScheduledPeftPages,
    std::unordered_set<uint64_t>& seenTaskIds);

std::tuple<RequestVector, RequestVector> MaxUtilizationScheduler::operator()(
    kv_cache_manager::BaseKVCacheManager& kvCacheManager, OptionalRef<BasePeftCacheManager const> peftCacheManager,
    RequestList const& activeRequests) const
{
    kvCacheManager.startScheduling();

    // The optimization of delaying requests won't work for variable window attention
    bool skippingIsRelevant = !kvCacheManager.getBlockManager().isVariableWindow();

    // Keep track of the number of requests and blocks needed for the scheduled requests
    auto scheduledBlocksManager
        = kv_cache_manager::MaxUtilizationScheduledBlocksManager(kvCacheManager, mTwoStepsLookAhead);
    SizeType32 numScheduledPeftPages{0};
    std::unordered_set<uint64_t> seenTaskIds;

    // Keep track of blocks contributed by requests in context phase
    auto [newlyContributedContextBlocks, newlyContributedCrossContextBlocks]
        = prefillWithChunkedContextsAlreadyExecuting(activeRequests, kvCacheManager);

    // Find the last started request in case we need to evict
    auto startedReqLambda = [this](std::shared_ptr<LlmRequest> const& req)
    {
        return (req->hasReachedState(getNoScheduleUntilState()) && !req->hasReachedState(getNoScheduleAfterState())
            && ((req->isContextInitState() && !req->isFirstContextChunk()) || req->isGenerationInProgressState()));
    };

    RequestVector scheduledRequests;
    RequestVector pausedRequests;
    auto reqItEnd = std::end(activeRequests);
    for (auto reqIt = std::begin(activeRequests); reqIt != reqItEnd;)
    {
        auto const& req = *reqIt;
        TLLM_LOG_DEBUG("MaxUtilizationScheduler: scheduling request ID %lu", req->mRequestId);

        // if request cannot be scheduled yet or request should no longer be scheduled, skip
        if ( // Allow disagg_generation_init requests to be scheduled, so that we'll allocate their KV cache
            !req->isDisaggGenerationInitState()
            && (!req->hasReachedState(getNoScheduleUntilState()) || req->hasReachedState(getNoScheduleAfterState())))
        {
            TLLM_LOG_DEBUG(
                "MaxUtilizationScheduler: request ID %lu cannot / should not be scheduled", req->mRequestId);
            reqIt++;
            continue;
        }

        // if context request can reuse blocks contributed by another context request, skip
        if (skippingIsRelevant
            && beneficialToSkip(
                req, kvCacheManager, std::nullopt, newlyContributedContextBlocks, newlyContributedCrossContextBlocks))
        {
            reqIt++;
            continue;
        }

        bool const wasScheduled = trySchedulingRequestMaxUtilization(req, mMaxNumRequests, scheduledRequests,
            scheduledBlocksManager, peftCacheManager, numScheduledPeftPages, seenTaskIds);
        if (wasScheduled)
        {
            TLLM_LOG_DEBUG("MaxUtilizationScheduler: request ID %lu -> start", req->mRequestId);
            reqIt++;
        }
        else
        {
            // If we can't allocate the current request, start pausing started requests
            // from the end of the vector and try again.
            auto const rbegin = std::reverse_iterator(reqItEnd);
            auto const rend = std::reverse_iterator(reqIt);
            auto const lastStartedReqIt = std::find_if(rbegin, rend, startedReqLambda);
            if (lastStartedReqIt != rend)
            {
                // Here we simulate freeing the kvCache blocks associated with that sequence
                kvCacheManager.schedulingRemoveSequence((*lastStartedReqIt)->mRequestId);
                pausedRequests.emplace_back(*lastStartedReqIt);
                TLLM_LOG_DEBUG("MaxUtilizationScheduler: request ID %lu -> pause", (*lastStartedReqIt)->mRequestId);
                reqItEnd = std::next(lastStartedReqIt).base();
            }
            else
            {
                break;
            }
        }
    }
    return {std::move(scheduledRequests), std::move(pausedRequests)};
}
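// Helper used by the loop above: tentatively fits one request against the KV cache block budget (through
// MaxUtilizationScheduledBlocksManager) and the PEFT page budget, and only commits the updated block count and
// LoRA task id if both budgets are satisfied.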
bool trySchedulingRequestMaxUtilization(std::shared_ptr<LlmRequest> const& req, SizeType32 maxNumRequests,
    RequestVector& scheduledRequests, kv_cache_manager::MaxUtilizationScheduledBlocksManager& blocksManager,
    OptionalRef<BasePeftCacheManager const> peftCacheManager, SizeType32& numScheduledPeftPages,
    std::unordered_set<uint64_t>& seenTaskIds)
{
    if (scheduledRequests.size() < static_cast<std::size_t>(maxNumRequests))
    {
        bool reqHasLora = req->getLoraTaskId().has_value();
        bool isNewTask = reqHasLora && !seenTaskIds.count(req->getLoraTaskId().value());
        SizeType32 numRequiredPeftPages
            = (isNewTask && peftCacheManager) ? peftCacheManager->determineNumPages(req) : 0;
        TLLM_LOG_DEBUG(
            "MaxUtilizationScheduler: request ID %lu required peft pages: %i", req->mRequestId, numRequiredPeftPages);
        auto const scheduledBlocksIfFitsKvCache = blocksManager.prepareNewNumberOfBlocksIfWeEndUpScheduling(*req);
        bool fitsPeft = (peftCacheManager
                ? numRequiredPeftPages + numScheduledPeftPages <= peftCacheManager->getMaxDevicePages()
                : true);
        if (scheduledBlocksIfFitsKvCache && fitsPeft)
        {
            blocksManager.updateScheduledBlocks(scheduledBlocksIfFitsKvCache.value());
            numScheduledPeftPages += numRequiredPeftPages;
            TLLM_LOG_DEBUG("MaxUtilizationScheduler: scheduled peft pages: %i", numRequiredPeftPages);
            scheduledRequests.emplace_back(req);
            if (isNewTask)
            {
                seenTaskIds.insert(req->getLoraTaskId().value());
            }
            return true;
        }
    }
    return false;
}

CapacityScheduler::CapacityScheduler(SizeType32 maxNumRequests,
    executor::CapacitySchedulerPolicy capacitySchedulerPolicy, bool hasKvCacheManager, bool twoStepsLookAhead,
    LlmRequestState noScheduleUntilState, LlmRequestState noScheduleAfterState)
{
    if (!hasKvCacheManager)
    {
        mScheduler = MaxRequestsScheduler{maxNumRequests, noScheduleUntilState, noScheduleAfterState};
    }
    else if (capacitySchedulerPolicy == executor::CapacitySchedulerPolicy::kMAX_UTILIZATION)
    {
        mScheduler
            = MaxUtilizationScheduler{maxNumRequests, twoStepsLookAhead, noScheduleUntilState, noScheduleAfterState};
    }
    else if (capacitySchedulerPolicy == executor::CapacitySchedulerPolicy::kGUARANTEED_NO_EVICT)
    {
        mScheduler = GuaranteedNoEvictScheduler{maxNumRequests, noScheduleUntilState, noScheduleAfterState};
    }
    else if (capacitySchedulerPolicy == executor::CapacitySchedulerPolicy::kSTATIC_BATCH)
    {
        mScheduler = StaticBatchScheduler{maxNumRequests, noScheduleUntilState, noScheduleAfterState};
    }
    else
    {
        throw std::runtime_error("Unsupported capacity scheduler policy");
    }
}
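// The constructor above selects one concrete scheduler and stores it in the mScheduler variant. operator() below
// dispatches on the stored alternative with std::visit and normalizes every policy's output into the common
// (fittingRequests, fittingDisaggGenInitRequests, pausedRequests) triple expected by the caller.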
TLLM_LOG_DEBUG("[Summary] Capacity scheduler allows %d requests, pauses %d requests", tmpFittingRequests.size(), pausedRequests.size()); RequestVector fittingRequests; RequestVector fittingDisaggGenInitRequests; for (auto const& llmReq : tmpFittingRequests) { if (llmReq->isDisaggGenerationInitState()) { fittingDisaggGenInitRequests.push_back(llmReq); } else { fittingRequests.push_back(llmReq); } } return {std::move(fittingRequests), std::move(fittingDisaggGenInitRequests), std::move(pausedRequests)}; }, mScheduler); } } // namespace tensorrt_llm::batch_manager