refactor: Unify request order in TRT and PyTorch workflow (#4096)

* chore: Partition context requests in MicroBatchScheduler

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>

* fixup! chore: Partition context requests in MicroBatchScheduler

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>

---------

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>
Author: Robin Kobus, 2025-05-20 18:49:27 +02:00 (committed by GitHub)
parent f038218f83
commit 8564c5a41f
4 changed files with 14 additions and 8 deletions


@@ -310,6 +310,14 @@ std::tuple<RequestVector, RequestVector> MicroBatchScheduler::operator()(Request
}
}
if (!allContextRequestsFit)
{
// Move context requests that reached the last context chunk to the end of the vector.
// This order is required for moveFinishedContextRequestsToGeneration.
std::partition(contextRequests.begin(), contextRequests.end(),
[](auto const& llmReq) { return !llmReq->isLastContextChunk(); });
}
TLLM_LOG_DEBUG(
"batchSize (num ctx/enc requests + num gen requests): %u", contextRequests.size() + generationRequests.size());
TLLM_LOG_DEBUG("batchNumTokens (num ctx/enc input tokens + num gen input tokens) / maxNumTokens: %d / %d",
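
For reference, a minimal standalone sketch (not TensorRT-LLM code; FakeRequest and its isLastContextChunk() are stand-ins for the real LlmRequest interface) of what the std::partition call above does: requests that still have context chunks left stay at the front, and requests on their last context chunk are pushed to the back, which is the order moveFinishedContextRequestsToGeneration relies on. Note that std::partition does not preserve relative order within each group.

// sketch_partition.cpp -- illustrative only, e.g. g++ -std=c++17 sketch_partition.cpp
#include <algorithm>
#include <iostream>
#include <memory>
#include <vector>

// Stand-in for LlmRequest; only the single accessor used by the predicate.
struct FakeRequest
{
    int id;
    bool lastChunk;
    bool isLastContextChunk() const { return lastChunk; }
};

int main()
{
    std::vector<std::shared_ptr<FakeRequest>> contextRequests;
    contextRequests.push_back(std::make_shared<FakeRequest>(FakeRequest{0, true}));
    contextRequests.push_back(std::make_shared<FakeRequest>(FakeRequest{1, false}));
    contextRequests.push_back(std::make_shared<FakeRequest>(FakeRequest{2, true}));
    contextRequests.push_back(std::make_shared<FakeRequest>(FakeRequest{3, false}));

    // Same call as in the diff: keep requests that are not on their last context
    // chunk at the front, move last-chunk requests to the end of the vector.
    std::partition(contextRequests.begin(), contextRequests.end(),
        [](auto const& llmReq) { return !llmReq->isLastContextChunk(); });

    for (auto const& req : contextRequests)
    {
        std::cout << "request " << req->id
                  << (req->isLastContextChunk() ? " (last context chunk)" : "") << "\n";
    }
    // Requests 1 and 3 print first, 0 and 2 last; order within each group is unspecified.
}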


@@ -953,11 +953,6 @@ void TrtGptModelInflightBatching::forwardAsync(RequestList const& activeRequests
= (*mMicroBatchScheduler)(fittingRequests, mInflightReqIds, mMaxBatchSizeRuntime, mMaxNumTokensRuntime);
TLLM_CHECK(currRequests.size() <= static_cast<size_t>(getMaxBatchSize()));
// Move context requests that reached the last context chunk to the end of the vector.
// This order is required for moveFinishedContextRequestsToGeneration.
std::partition(currRequests.contextRequests.begin(), currRequests.contextRequests.end(),
[](auto const& llmReq) { return !llmReq->isLastContextChunk(); });
(*mPauseRequests)(requestsToPause, mInflightReqIds, mReqIdsToPause, false, *mSeqSlotManager, mKvCacheManager,
mCrossKvCacheManager, mPeftCacheManager);


@@ -268,6 +268,10 @@ class TorchSampler(Sampler):
} for token, logprob in zip(tokens, log_probs.tolist())]
request.py_result.append_log_probs([token_log_probs])
if hasattr(scheduled_requests, 'chunked_requests'):
request_idx += len(scheduled_requests.chunked_requests)
token_idx += len(scheduled_requests.chunked_requests)
for request in scheduled_requests.context_requests:
if request.get_context_remaining_length() != 0:
advance_idx()
@@ -282,9 +286,6 @@
request.py_decoding_iter += 1
advance_idx()
if hasattr(scheduled_requests, 'chunked_requests'):
request_idx += len(scheduled_requests.chunked_requests)
extend_requests = []
generation_requests = []
for request in scheduled_requests.generation_requests:

File diff suppressed because one or more lines are too long