mirror of https://github.com/NVIDIA/TensorRT-LLM.git
refactor: Unify request order in TRT and PyTorch workflow (#4096)
* chore: Partition context requests in MicroBatchScheduler

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>

* fixup! chore: Partition context requests in MicroBatchScheduler

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>

---------

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>
This commit is contained in:
parent: f038218f83
commit: 8564c5a41f
@@ -310,6 +310,14 @@ std::tuple<RequestVector, RequestVector> MicroBatchScheduler::operator()(Request
         }
     }
 
+    if (!allContextRequestsFit)
+    {
+        // Move context requests that reached the last context chunk to the end of the vector.
+        // This order is required for moveFinishedContextRequestsToGeneration.
+        std::partition(contextRequests.begin(), contextRequests.end(),
+            [](auto const& llmReq) { return !llmReq->isLastContextChunk(); });
+    }
+
     TLLM_LOG_DEBUG(
         "batchSize (num ctx/enc requests + num gen requests): %u", contextRequests.size() + generationRequests.size());
     TLLM_LOG_DEBUG("batchNumTokens (num ctx/enc input tokens + num gen input tokens) / maxNumTokens: %d / %d",
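The added block reorders contextRequests in place before the batch is returned: requests still producing context chunks stay at the front, and requests on their last context chunk are moved to a contiguous tail. Below is a minimal, runnable Python sketch of the same reordering, where Request is a hypothetical stand-in for LlmRequest; note the real code uses std::partition, which does not preserve relative order within each group, while the stable sort here does.

from dataclasses import dataclass


@dataclass
class Request:
    req_id: int
    is_last_context_chunk: bool


context_requests = [
    Request(0, True),   # reached its last context chunk
    Request(1, False),  # still has context chunks left
    Request(2, True),
    Request(3, False),
]

# Analog of std::partition with predicate !isLastContextChunk(): requests
# still chunking stay in front; requests that finished their context phase
# form a contiguous tail (False sorts before True).
context_requests.sort(key=lambda r: r.is_last_context_chunk)
print([r.req_id for r in context_requests])  # [1, 3, 0, 2]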
@@ -953,11 +953,6 @@ void TrtGptModelInflightBatching::forwardAsync(RequestList const& activeRequests
             = (*mMicroBatchScheduler)(fittingRequests, mInflightReqIds, mMaxBatchSizeRuntime, mMaxNumTokensRuntime);
         TLLM_CHECK(currRequests.size() <= static_cast<size_t>(getMaxBatchSize()));
 
-        // Move context requests that reached the last context chunk to the end of the vector.
-        // This order is required for moveFinishedContextRequestsToGeneration.
-        std::partition(currRequests.contextRequests.begin(), currRequests.contextRequests.end(),
-            [](auto const& llmReq) { return !llmReq->isLastContextChunk(); });
-
         (*mPauseRequests)(requestsToPause, mInflightReqIds, mReqIdsToPause, false, *mSeqSlotManager, mKvCacheManager,
             mCrossKvCacheManager, mPeftCacheManager);
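This is the call site the previous hunk replaces: the partition now happens inside MicroBatchScheduler::operator(), so both the TRT and PyTorch workflows receive batches in the same order instead of forwardAsync reordering them afterwards. The removed comment names moveFinishedContextRequestsToGeneration as the consumer of that order; its body is not part of this diff, but a hypothetical sketch of such a helper shows why the contiguous tail matters:

def move_finished_context_requests_to_generation(context_requests, generation_requests):
    """Hypothetical sketch, not the real helper. It relies on the partition
    invariant: every request at or past the first finished one is finished
    too, so the hand-off is one slice instead of a scan with per-element
    removals."""
    first_finished = next(
        (i for i, r in enumerate(context_requests) if r.is_last_context_chunk),
        len(context_requests),
    )
    generation_requests.extend(context_requests[first_finished:])
    del context_requests[first_finished:]


# Reusing the Request stand-in from the previous sketch:
ctx = [Request(1, False), Request(3, False), Request(0, True), Request(2, True)]
gen = []
move_finished_context_requests_to_generation(ctx, gen)
print([r.req_id for r in ctx], [r.req_id for r in gen])  # [1, 3] [0, 2]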
@@ -268,6 +268,10 @@ class TorchSampler(Sampler):
                 } for token, logprob in zip(tokens, log_probs.tolist())]
                 request.py_result.append_log_probs([token_log_probs])
 
+        if hasattr(scheduled_requests, 'chunked_requests'):
+            request_idx += len(scheduled_requests.chunked_requests)
+            token_idx += len(scheduled_requests.chunked_requests)
+
         for request in scheduled_requests.context_requests:
             if request.get_context_remaining_length() != 0:
                 advance_idx()
@@ -282,9 +286,6 @@ class TorchSampler(Sampler):
                 request.py_decoding_iter += 1
                 advance_idx()
 
-        if hasattr(scheduled_requests, 'chunked_requests'):
-            request_idx += len(scheduled_requests.chunked_requests)
-
         extend_requests = []
         generation_requests = []
         for request in scheduled_requests.generation_requests:
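These two sampler hunks move the chunked-request bookkeeping from after the context loop to before it, which suggests that under the unified order the still-chunking requests now precede the other context requests in the batch layout (the old placement also never advanced token_idx, which this change fixes). A toy walk-through with hypothetical request names, where the inline increments play the role of the sampler's advance_idx helper:

chunked_requests = ["reqA"]          # mid-context: no token sampled this step
context_requests = ["reqB", "reqC"]  # on their last context chunk

request_idx = -1
token_idx = -1

# New placement: skip the chunked requests' slots up front, since they come
# before the context requests in the scheduled batch.
request_idx += len(chunked_requests)
token_idx += len(chunked_requests)

for request in context_requests:
    # advance_idx() in the real sampler
    request_idx += 1
    token_idx += 1

# Both indices now point at the last context request's slot.
assert (request_idx, token_idx) == (2, 2)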
File diff suppressed because one or more lines are too long