[nvbugs/5274894] fix: Sort requests for functional correctness and performance (adapted from #4608) (#4621)
- Moved sorting-related logic to a dedicated function for better clarity and maintainability.
- Enhanced sorting logic to separate finished context requests from ongoing ones before sorting by LoRA task ID.
- Updated function documentation to reflect the sorting behavior and its purpose.

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>
parent fd27f89df6
commit 93a54457ac
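Before the diff itself, a minimal standalone sketch of the pattern the commit adopts: context requests on their last chunk are partitioned behind ongoing ones, and each range is then sorted by LoRA task ID independently. Everything here (the Req struct, its field names, sortRequestsSketch) is a hypothetical stand-in for the real LlmRequest API, not TensorRT-LLM code.

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <memory>
#include <vector>

// Hypothetical stand-in for LlmRequest: only the fields the sort needs.
struct Req
{
    std::uint64_t loraTaskId; // stand-in for getLoraTaskId()
    bool lastContextChunk;    // stand-in for isLastContextChunk()
};

using ReqVector = std::vector<std::shared_ptr<Req>>;

void sortRequestsSketch(ReqVector& contextRequests, ReqVector& generationRequests, bool chunksPresent)
{
    auto sortByLoraId = [](ReqVector::iterator begin, ReqVector::iterator end)
    {
        std::sort(begin, end,
            [](auto const& lhs, auto const& rhs) { return lhs->loraTaskId < rhs->loraTaskId; });
    };

    if (chunksPresent)
    {
        // Requests on their last context chunk go to the back of the vector,
        // the order the move-to-generation step relies on.
        auto firstFinished = std::partition(contextRequests.begin(), contextRequests.end(),
            [](auto const& r) { return !r->lastContextChunk; });
        // Sort each range separately so finished requests stay at the back.
        sortByLoraId(contextRequests.begin(), firstFinished);
        sortByLoraId(firstFinished, contextRequests.end());
    }
    else
    {
        sortByLoraId(contextRequests.begin(), contextRequests.end());
    }
    sortByLoraId(generationRequests.begin(), generationRequests.end());
}

int main()
{
    ReqVector ctx{std::make_shared<Req>(Req{2, true}), std::make_shared<Req>(Req{1, false}),
        std::make_shared<Req>(Req{3, false})};
    ReqVector gen{std::make_shared<Req>(Req{5, true}), std::make_shared<Req>(Req{4, true})};
    sortRequestsSketch(ctx, gen, /*chunksPresent=*/true);
    for (auto const& r : ctx)
    {
        std::printf("task %llu, last chunk: %d\n", (unsigned long long) r->loraTaskId, r->lastContextChunk);
    }
    // Expected: tasks 1 and 3 (ongoing, sorted), then task 2 (last chunk).
}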
@@ -16,10 +16,9 @@
 */
 
 #include "tensorrt_llm/batch_manager/microBatchScheduler.h"
+#include "tensorrt_llm/batch_manager/utils/inflightBatchingUtils.h"
 #include "tensorrt_llm/common/nvtxUtils.h"
 
 namespace tle = tensorrt_llm::executor;
 
 namespace tensorrt_llm::batch_manager
 {
@@ -310,13 +309,7 @@ std::tuple<RequestVector, RequestVector> MicroBatchScheduler::operator()(Request
         }
     }
 
-    if (!allContextRequestsFit)
-    {
-        // Move context requests that reached the last context chunk to the end of the vector.
-        // This order is required for moveFinishedContextRequestsToGeneration.
-        std::partition(contextRequests.begin(), contextRequests.end(),
-            [](auto const& llmReq) { return !llmReq->isLastContextChunk(); });
-    }
+    utils::sortRequests(contextRequests, generationRequests, !allContextRequestsFit);
 
     TLLM_LOG_DEBUG(
         "batchSize (num ctx/enc requests + num gen requests): %u", contextRequests.size() + generationRequests.size());
@@ -978,8 +978,6 @@ void TrtGptModelInflightBatching::forwardAsync(RequestList const& activeRequests
         }
     }
 
-    utils::sortByLoraId(currRequests);
-
     (*mAssignReqSeqSlots)(*mSeqSlotManager, currRequests.contextRequests, currRequests.generationRequests);
 
     if (mKvCacheManager)
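Sorting now happens inside the scheduler via utils::sortRequests (previous hunk), so the separate utils::sortByLoraId call in forwardAsync is dropped rather than relocated.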
@@ -39,17 +39,32 @@ TensorPtr collectRequestIds(RequestVector const& contextRequests, RequestVector
     return requestIds;
 }
 
-void sortByLoraId(ScheduledRequests& scheduledRequests)
+void sortRequests(RequestVector& contextRequests, RequestVector& generationRequests, bool chunksPresent)
 {
     TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
 
-    auto sortRequests = [](RequestVector& requests)
+    auto sortByLoraId = [](RequestVector::iterator begin, RequestVector::iterator end)
     {
-        std::sort(requests.begin(), requests.end(),
-            [](auto const& lhs, auto const& rhs) { return lhs->getLoraTaskId() < rhs->getLoraTaskId(); });
+        std::sort(
+            begin, end, [](auto const& lhs, auto const& rhs) { return lhs->getLoraTaskId() < rhs->getLoraTaskId(); });
     };
-    sortRequests(scheduledRequests.contextRequests);
-    sortRequests(scheduledRequests.generationRequests);
+
+    if (chunksPresent)
+    {
+        // Move context requests that reached the last context chunk to the end of the vector.
+        // This order is required for moveFinishedContextRequestsToGeneration.
+        auto firstFinished = std::partition(contextRequests.begin(), contextRequests.end(),
+            [](auto const& llmReq) { return !llmReq->isLastContextChunk(); });
+
+        // Sort context requests by lora task id, but keep finished requests separate.
+        sortByLoraId(contextRequests.begin(), firstFinished);
+        sortByLoraId(firstFinished, contextRequests.end());
+    }
+    else
+    {
+        sortByLoraId(contextRequests.begin(), contextRequests.end());
+    }
+    sortByLoraId(generationRequests.begin(), generationRequests.end());
 
     TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
 }
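A design observation, not part of the commit message: std::partition makes no stability guarantee within each group, but that is harmless here since both ranges are immediately re-sorted by LoRA task ID; grouping requests with equal task IDs is presumably what lets consecutive requests reuse the same LoRA weights. Note also that the call site above passes !allContextRequestsFit as chunksPresent, so the partition step only runs when chunked context is in play.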
@@ -63,7 +78,9 @@ void moveFinishedContextRequestsToGeneration(ScheduledRequests& scheduledRequest
     auto firstFinished = std::find_if(
         contextRequests.begin(), contextRequests.end(), [](auto const& llmReq) { return llmReq->isContextFinished(); });
     TLLM_LOG_DEBUG(
-        "Moving %ld finished context requests to generation.", std::distance(firstFinished, contextRequests.end()));
+        "Found %ld unfinished chunked context requests. Found %ld finished context requests, moving them to "
+        "generation.",
+        std::distance(contextRequests.begin(), firstFinished), std::distance(firstFinished, contextRequests.end()));
     generationRequests.insert(generationRequests.begin(), std::make_move_iterator(firstFinished),
         std::make_move_iterator(contextRequests.end()));
     contextRequests.erase(firstFinished, contextRequests.end());
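Because sortRequests already moved finished context requests behind the unfinished ones, std::find_if lands on a contiguous tail, and the make_move_iterator insert plus erase splices that whole tail into generationRequests in one pass; the expanded log message now reports both the unfinished and finished counts.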
@@ -35,7 +35,13 @@ using OptionalRef = common::OptionalRef<T>;
 
 TensorPtr collectRequestIds(RequestVector const& contextRequests, RequestVector const& generationRequests);
 
-void sortByLoraId(ScheduledRequests& scheduledRequests);
+//! @brief Sort requests for functional correctness and performance.
+//! @details Sort context requests for moveFinishedContextRequestsToGeneration.
+//! Sort requests by lora task id for performance.
+//! @param contextRequests The context requests.
+//! @param generationRequests The generation requests.
+//! @param chunksPresent Whether context chunks are present.
+void sortRequests(RequestVector& contextRequests, RequestVector& generationRequests, bool chunksPresent);
 
 //! @brief Move finished context requests to generation requests.
 //! @details This function assumes that the context requests are sorted so that requests with isLastContextChunk() are