[nvbugs/5274894] fix: Sort requests for functional correctness and performance (adapted from #4608) (#4621)

- Moved sorting related logic to a dedicated function for better clarity and maintainability.
- Enhanced sorting logic to separate finished context requests from ongoing ones before sorting by Lora task ID.
- Updated function documentation to reflect the sorting behavior and its purpose.

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>
This commit is contained in:
Robin Kobus 2025-05-26 11:10:55 +02:00 committed by GitHub
parent fd27f89df6
commit 93a54457ac
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 33 additions and 19 deletions

View File

@@ -16,10 +16,9 @@
*/
#include "tensorrt_llm/batch_manager/microBatchScheduler.h"
#include "tensorrt_llm/batch_manager/utils/inflightBatchingUtils.h"
#include "tensorrt_llm/common/nvtxUtils.h"
namespace tle = tensorrt_llm::executor;
namespace tensorrt_llm::batch_manager
{
@@ -310,13 +309,7 @@ std::tuple<RequestVector, RequestVector> MicroBatchScheduler::operator()(Request
}
}
if (!allContextRequestsFit)
{
// Move context requests that reached the last context chunk to the end of the vector.
// This order is required for moveFinishedContextRequestsToGeneration.
std::partition(contextRequests.begin(), contextRequests.end(),
[](auto const& llmReq) { return !llmReq->isLastContextChunk(); });
}
utils::sortRequests(contextRequests, generationRequests, !allContextRequestsFit);
TLLM_LOG_DEBUG(
"batchSize (num ctx/enc requests + num gen requests): %u", contextRequests.size() + generationRequests.size());

View File

@@ -978,8 +978,6 @@ void TrtGptModelInflightBatching::forwardAsync(RequestList const& activeRequests
}
}
utils::sortByLoraId(currRequests);
(*mAssignReqSeqSlots)(*mSeqSlotManager, currRequests.contextRequests, currRequests.generationRequests);
if (mKvCacheManager)

View File

@@ -39,17 +39,32 @@ TensorPtr collectRequestIds(RequestVector const& contextRequests, RequestVector
return requestIds;
}
void sortByLoraId(ScheduledRequests& scheduledRequests)
void sortRequests(RequestVector& contextRequests, RequestVector& generationRequests, bool chunksPresent)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto sortRequests = [](RequestVector& requests)
auto sortByLoraId = [](RequestVector::iterator begin, RequestVector::iterator end)
{
std::sort(requests.begin(), requests.end(),
[](auto const& lhs, auto const& rhs) { return lhs->getLoraTaskId() < rhs->getLoraTaskId(); });
std::sort(
begin, end, [](auto const& lhs, auto const& rhs) { return lhs->getLoraTaskId() < rhs->getLoraTaskId(); });
};
sortRequests(scheduledRequests.contextRequests);
sortRequests(scheduledRequests.generationRequests);
if (chunksPresent)
{
// Move context requests that reached the last context chunk to the end of the vector.
// This order is required for moveFinishedContextRequestsToGeneration.
auto firstFinished = std::partition(contextRequests.begin(), contextRequests.end(),
[](auto const& llmReq) { return !llmReq->isLastContextChunk(); });
// Sort context requests by lora task id, but keep finished requests separate.
sortByLoraId(contextRequests.begin(), firstFinished);
sortByLoraId(firstFinished, contextRequests.end());
}
else
{
sortByLoraId(contextRequests.begin(), contextRequests.end());
}
sortByLoraId(generationRequests.begin(), generationRequests.end());
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
@@ -63,7 +78,9 @@ void moveFinishedContextRequestsToGeneration(ScheduledRequests& scheduledRequest
auto firstFinished = std::find_if(
contextRequests.begin(), contextRequests.end(), [](auto const& llmReq) { return llmReq->isContextFinished(); });
TLLM_LOG_DEBUG(
"Moving %ld finished context requests to generation.", std::distance(firstFinished, contextRequests.end()));
"Found %ld unfinished chunked context requests. Found %ld finished context requests, moving them to "
"generation.",
std::distance(contextRequests.begin(), firstFinished), std::distance(firstFinished, contextRequests.end()));
generationRequests.insert(generationRequests.begin(), std::make_move_iterator(firstFinished),
std::make_move_iterator(contextRequests.end()));
contextRequests.erase(firstFinished, contextRequests.end());

View File

@@ -35,7 +35,13 @@ using OptionalRef = common::OptionalRef<T>;
TensorPtr collectRequestIds(RequestVector const& contextRequests, RequestVector const& generationRequests);
void sortByLoraId(ScheduledRequests& scheduledRequests);
//! @brief Sort requests for functional correctness and performance.
//! @details Sort context requests for moveFinishedContextRequestsToGeneration.
//! Sort requests by lora task id for performance.
//! @param contextRequests The context requests.
//! @param generationRequests The generation requests.
//! @param chunksPresent Whether context chunks are present.
void sortRequests(RequestVector& contextRequests, RequestVector& generationRequests, bool chunksPresent);
//! @brief Move finished context requests to generation requests.
//! @details This function assumes that the context requests are sorted so that requests with isLastContextChunk() are