From 93a54457ac1d103e7788e6f581506336e47d1d6f Mon Sep 17 00:00:00 2001 From: Robin Kobus <19427718+Funatiq@users.noreply.github.com> Date: Mon, 26 May 2025 11:10:55 +0200 Subject: [PATCH] [nvbugs/5274894] fix: Sort requests for functional correctness and performance (adapted from #4608) (#4621) - Moved sorting related logic to a dedicated function for better clarity and maintainability. - Enhanced sorting logic to separate finished context requests from ongoing ones before sorting by Lora task ID. - Updated function documentation to reflect the sorting behavior and its purpose. Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com> --- .../batch_manager/microBatchScheduler.cpp | 11 ++----- .../trtGptModelInflightBatching.cpp | 2 -- .../utils/inflightBatchingUtils.cpp | 31 ++++++++++++++----- .../utils/inflightBatchingUtils.h | 8 ++++- 4 files changed, 33 insertions(+), 19 deletions(-) diff --git a/cpp/tensorrt_llm/batch_manager/microBatchScheduler.cpp b/cpp/tensorrt_llm/batch_manager/microBatchScheduler.cpp index 169108e3da..6a2dc46d53 100644 --- a/cpp/tensorrt_llm/batch_manager/microBatchScheduler.cpp +++ b/cpp/tensorrt_llm/batch_manager/microBatchScheduler.cpp @@ -16,10 +16,9 @@ */ #include "tensorrt_llm/batch_manager/microBatchScheduler.h" +#include "tensorrt_llm/batch_manager/utils/inflightBatchingUtils.h" #include "tensorrt_llm/common/nvtxUtils.h" -namespace tle = tensorrt_llm::executor; - namespace tensorrt_llm::batch_manager { @@ -310,13 +309,7 @@ std::tuple MicroBatchScheduler::operator()(Request } } - if (!allContextRequestsFit) - { - // Move context requests that reached the last context chunk to the end of the vector. - // This order is required for moveFinishedContextRequestsToGeneration. - std::partition(contextRequests.begin(), contextRequests.end(), - [](auto const& llmReq) { return !llmReq->isLastContextChunk(); }); - } + utils::sortRequests(contextRequests, generationRequests, !allContextRequestsFit); TLLM_LOG_DEBUG( "batchSize (num ctx/enc requests + num gen requests): %u", contextRequests.size() + generationRequests.size()); diff --git a/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp b/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp index 2593a7d425..0a1b6f03ec 100644 --- a/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp +++ b/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp @@ -978,8 +978,6 @@ void TrtGptModelInflightBatching::forwardAsync(RequestList const& activeRequests } } - utils::sortByLoraId(currRequests); - (*mAssignReqSeqSlots)(*mSeqSlotManager, currRequests.contextRequests, currRequests.generationRequests); if (mKvCacheManager) diff --git a/cpp/tensorrt_llm/batch_manager/utils/inflightBatchingUtils.cpp b/cpp/tensorrt_llm/batch_manager/utils/inflightBatchingUtils.cpp index b34bd9cf68..466146f307 100644 --- a/cpp/tensorrt_llm/batch_manager/utils/inflightBatchingUtils.cpp +++ b/cpp/tensorrt_llm/batch_manager/utils/inflightBatchingUtils.cpp @@ -39,17 +39,32 @@ TensorPtr collectRequestIds(RequestVector const& contextRequests, RequestVector return requestIds; } -void sortByLoraId(ScheduledRequests& scheduledRequests) +void sortRequests(RequestVector& contextRequests, RequestVector& generationRequests, bool chunksPresent) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); - auto sortRequests = [](RequestVector& requests) + auto sortByLoraId = [](RequestVector::iterator begin, RequestVector::iterator end) { - std::sort(requests.begin(), requests.end(), - [](auto const& lhs, auto const& rhs) { return lhs->getLoraTaskId() < rhs->getLoraTaskId(); }); + std::sort( + begin, end, [](auto const& lhs, auto const& rhs) { return lhs->getLoraTaskId() < rhs->getLoraTaskId(); }); }; - sortRequests(scheduledRequests.contextRequests); - sortRequests(scheduledRequests.generationRequests); + + if (chunksPresent) + { + // Move context requests that reached the last context chunk to the end of the vector. + // This order is required for moveFinishedContextRequestsToGeneration. + auto firstFinished = std::partition(contextRequests.begin(), contextRequests.end(), + [](auto const& llmReq) { return !llmReq->isLastContextChunk(); }); + + // Sort context requests by lora task id, but keep finished requests separate. + sortByLoraId(contextRequests.begin(), firstFinished); + sortByLoraId(firstFinished, contextRequests.end()); + } + else + { + sortByLoraId(contextRequests.begin(), contextRequests.end()); + } + sortByLoraId(generationRequests.begin(), generationRequests.end()); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } @@ -63,7 +78,9 @@ void moveFinishedContextRequestsToGeneration(ScheduledRequests& scheduledRequest auto firstFinished = std::find_if( contextRequests.begin(), contextRequests.end(), [](auto const& llmReq) { return llmReq->isContextFinished(); }); TLLM_LOG_DEBUG( - "Moving %ld finished context requests to generation.", std::distance(firstFinished, contextRequests.end())); + "Found %ld unfinished chunked context requests. Found %ld finished context requests, moving them to " + "generation.", + std::distance(contextRequests.begin(), firstFinished), std::distance(firstFinished, contextRequests.end())); generationRequests.insert(generationRequests.begin(), std::make_move_iterator(firstFinished), std::make_move_iterator(contextRequests.end())); contextRequests.erase(firstFinished, contextRequests.end()); diff --git a/cpp/tensorrt_llm/batch_manager/utils/inflightBatchingUtils.h b/cpp/tensorrt_llm/batch_manager/utils/inflightBatchingUtils.h index 2095dce2d0..d34b3e4e38 100644 --- a/cpp/tensorrt_llm/batch_manager/utils/inflightBatchingUtils.h +++ b/cpp/tensorrt_llm/batch_manager/utils/inflightBatchingUtils.h @@ -35,7 +35,13 @@ using OptionalRef = common::OptionalRef; TensorPtr collectRequestIds(RequestVector const& contextRequests, RequestVector const& generationRequests); -void sortByLoraId(ScheduledRequests& scheduledRequests); +//! @brief Sort requests for functional correctness and performance. +//! @details Sort context requests for moveFinishedContextRequestsToGeneration. +//! Sort requests by lora task id for performance. +//! @param contextRequests The context requests. +//! @param generationRequests The generation requests. +//! @param chunksPresent Whether context chunks are present. +void sortRequests(RequestVector& contextRequests, RequestVector& generationRequests, bool chunksPresent); //! @brief Move finished context requests to generation requests. //! @details This function assumes that the context requests are sorted so that requests with isLastContextChunk() are