fix: max_num_sequences calculation with overlap scheduling (#4532)

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>
Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com>
Co-authored-by: Daniel Campora <961215+dcampora@users.noreply.github.com>
Robin Kobus 2025-06-03 09:31:22 +02:00 committed by GitHub
parent 320195dc0d
commit b9263a8e10
9 changed files with 25 additions and 31 deletions

View File

@@ -88,7 +88,7 @@ private:
class MaxUtilizationScheduler : public BaseCapacityScheduler
{
public:
- MaxUtilizationScheduler(SizeType32 maxNumRequests, bool manyMicroBatches,
+ MaxUtilizationScheduler(SizeType32 maxNumRequests, bool twoStepsLookAhead,
LlmRequestState noScheduleUntilState = LlmRequestState::kCONTEXT_INIT,
LlmRequestState noScheduleAfterState = LlmRequestState::kGENERATION_COMPLETE);
@@ -98,8 +98,8 @@ public:
private:
SizeType32 mMaxNumRequests;
- /// @brief Boolean that indicates if multiple micro batches might be in flight
- bool mManyMicroBatches;
+ /// @brief Boolean that indicates if two step lookahead is enabled
+ bool mTwoStepsLookAhead;
};
/// @brief Schedule requests using the GUARANTEED_NO_EVICT policy
@@ -146,7 +146,7 @@ public:
constexpr static auto name{"CapacityScheduler"};
explicit CapacityScheduler(SizeType32 maxNumRequests, executor::CapacitySchedulerPolicy capacitySchedulerPolicy,
- bool hasKvCacheManager, std::optional<bool> manyMicroBatches = std::nullopt,
+ bool hasKvCacheManager, bool twoStepsLookAhead = false,
LlmRequestState noScheduleUntilState = LlmRequestState::kCONTEXT_INIT,
LlmRequestState noScheduleAfterState = LlmRequestState::kGENERATION_COMPLETE);
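
The rename from manyMicroBatches to twoStepsLookAhead names the behavior rather than the trigger: whenever more than one micro batch can be in flight, the MAX_UTILIZATION scheduler must budget KV-cache blocks for the next decoding step as well as the current one. A minimal sketch of that idea, assuming a simplified block-demand estimate (blocks_needed is a hypothetical helper, not the actual TensorRT-LLM code):

    # Hypothetical helper, for illustration only.
    def blocks_needed(num_tokens: int, tokens_per_block: int,
                      two_steps_lookahead: bool) -> int:
        # Reserve room for one newly generated token, or for two when the
        # scheduler must plan one extra decoding step ahead.
        lookahead_tokens = 2 if two_steps_lookahead else 1
        return -(-(num_tokens + lookahead_tokens) // tokens_per_block)  # ceiling division

    print(blocks_needed(127, 64, False))  # 2
    print(blocks_needed(127, 64, True))   # 3: the extra step crosses a block boundary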

View File

@@ -129,11 +129,11 @@ MaxRequestsScheduler::MaxRequestsScheduler(
{
}
- MaxUtilizationScheduler::MaxUtilizationScheduler(SizeType32 maxNumRequests, bool manyMicroBatches,
+ MaxUtilizationScheduler::MaxUtilizationScheduler(SizeType32 maxNumRequests, bool twoStepsLookAhead,
LlmRequestState noScheduleUntilState, LlmRequestState noScheduleAfterState)
: BaseCapacityScheduler(noScheduleUntilState, noScheduleAfterState)
, mMaxNumRequests(maxNumRequests)
- , mManyMicroBatches{manyMicroBatches}
+ , mTwoStepsLookAhead{twoStepsLookAhead}
{
}
@@ -346,7 +346,7 @@ std::tuple<RequestVector, RequestVector> MaxUtilizationScheduler::operator()(
// Keep track of number of requests and block needed for the scheduled requests
auto scheduledBlocksManager
- = kv_cache_manager::MaxUtilizationScheduledBlocksManager(kvCacheManager, mManyMicroBatches);
+ = kv_cache_manager::MaxUtilizationScheduledBlocksManager(kvCacheManager, mTwoStepsLookAhead);
SizeType32 numScheduledPeftPages{0};
std::unordered_set<uint64_t> seenTaskIds;
@@ -456,8 +456,8 @@ bool trySchedulingRequestMaxUtilization(std::shared_ptr<LlmRequest> const& req,
}
CapacityScheduler::CapacityScheduler(SizeType32 maxNumRequests,
- executor::CapacitySchedulerPolicy capacitySchedulerPolicy, bool hasKvCacheManager,
- std::optional<bool> manyMicroBatches, LlmRequestState noScheduleUntilState, LlmRequestState noScheduleAfterState)
+ executor::CapacitySchedulerPolicy capacitySchedulerPolicy, bool hasKvCacheManager, bool twoStepsLookAhead,
+ LlmRequestState noScheduleUntilState, LlmRequestState noScheduleAfterState)
{
if (!hasKvCacheManager)
{
@@ -465,8 +465,8 @@ CapacityScheduler::CapacityScheduler(SizeType32 maxNumRequests,
}
else if (capacitySchedulerPolicy == executor::CapacitySchedulerPolicy::kMAX_UTILIZATION)
{
- mScheduler = MaxUtilizationScheduler{
-     maxNumRequests, manyMicroBatches ? *manyMicroBatches : false, noScheduleUntilState, noScheduleAfterState};
+ mScheduler
+     = MaxUtilizationScheduler{maxNumRequests, twoStepsLookAhead, noScheduleUntilState, noScheduleAfterState};
}
else if (capacitySchedulerPolicy == executor::CapacitySchedulerPolicy::kGUARANTEED_NO_EVICT)
{

View File

@@ -75,8 +75,8 @@ TrtEncoderModel::TrtEncoderModel(runtime::ModelConfig const& modelConfig, WorldC
// handling of maximizing utilization or pause/evict
// TODO: finer control on encoder requests scheduling
mCapacityScheduler = std::make_unique<tensorrt_llm::batch_manager::CapacityScheduler>(
- getMaxBatchSize() * mNumMicroBatches, optionalParams.schedulerConfig.getCapacitySchedulerPolicy(), false,
- std::nullopt, LlmRequestState::kENCODER_INIT, LlmRequestState::kCONTEXT_INIT);
+ getMaxBatchSize() * mNumMicroBatches, optionalParams.schedulerConfig.getCapacitySchedulerPolicy(), false, false,
+ LlmRequestState::kENCODER_INIT, LlmRequestState::kCONTEXT_INIT);
mMicroBatchScheduler = std::make_unique<tensorrt_llm::batch_manager::MicroBatchScheduler>(
std::nullopt, mModelConfig.getMaxInputLen(), LlmRequestState::kENCODER_INIT, LlmRequestState::kCONTEXT_INIT);

View File

@@ -116,9 +116,7 @@ public:
? optionalParams.kvCacheConfig.sinkTokenLength.value()
: 0;
- auto const numBatches
-     = worldConfig.isPipelineParallel() ? worldConfig.getPipelineParallelism() : (mEnableTrtOverlap ? 2 : 1);
- mMaxNumSequences = numBatches * mMaxBatchSize;
+ mMaxNumSequences = mMaxBatchSize * worldConfig.getPipelineParallelism();
auto const numTotalAttenLayers = modelConfig.getNbAttentionLayers();
auto const numRepeatsAttenWindow = numTotalAttenLayers / mMaxAttentionWindowVec.size();
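
This hunk is the heart of the fix. Previously, enabling overlap scheduling without pipeline parallelism doubled mMaxNumSequences, even though overlap interleaves two steps of the same requests rather than running two independent batches. Restating the before/after arithmetic in a short Python sketch (values are illustrative):

    max_batch_size, pp_size, enable_trt_overlap = 8, 1, True

    # Old: overlap counted as a second batch of sequences.
    num_batches = pp_size if pp_size > 1 else (2 if enable_trt_overlap else 1)
    max_num_sequences_old = num_batches * max_batch_size  # 16

    # New: one batch of sequence slots per pipeline stage, regardless of overlap.
    max_num_sequences_new = max_batch_size * pp_size      # 8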

View File

@@ -412,7 +412,7 @@ TrtGptModelInflightBatching::TrtGptModelInflightBatching(std::shared_ptr<nvinfer
ctxChunkConfig.value().chunkUnitSize, mKvCacheManager->getTokensPerBlock());
}
- mCapacityScheduler = std::make_unique<CapacityScheduler>(getMaxBatchSize() * mNumMicroBatches,
+ mCapacityScheduler = std::make_unique<CapacityScheduler>(getMaxNumSequences(),
optionalParams.schedulerConfig.getCapacitySchedulerPolicy(), mKvCacheManager != nullptr, mNumMicroBatches > 1);
mMicroBatchScheduler = std::make_unique<MicroBatchScheduler>(ctxChunkConfig, maxContextLength);
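
The scheduler's request budget is now the corrected getMaxNumSequences() rather than getMaxBatchSize() * mNumMicroBatches, and mNumMicroBatches > 1 supplies the lookahead flag, since both overlap scheduling and pipeline parallelism keep more than one micro batch in flight. Roughly, under the same assumed naming as the removed numBatches logic above:

    pp_size, enable_trt_overlap = 1, True
    num_micro_batches = pp_size if pp_size > 1 else (2 if enable_trt_overlap else 1)
    two_steps_lookahead = num_micro_batches > 1  # True: overlap keeps two micro batches in flight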

View File

@@ -56,7 +56,7 @@ void tensorrt_llm::pybind::batch_manager::algorithms::initBindings(pybind11::mod
py::class_<CapacityScheduler>(m, CapacityScheduler::name)
.def(py::init<SizeType32, executor::CapacitySchedulerPolicy, bool, bool, LlmRequestState, LlmRequestState>(),
py::arg("max_num_requests"), py::arg("capacity_scheduler_policy"), py::arg("has_kv_cache_manager"),
py::arg("many_micro_batches") = false,
py::arg("two_step_lookahead") = false,
py::arg_v("no_schedule_until_state", LlmRequestState::kCONTEXT_INIT, "LlmRequestState.CONTEXT_INIT"),
py::arg_v("no_schedule_after_state", LlmRequestState::kGENERATION_COMPLETE,
"LlmRequestState.GENERATION_COMPLETE"))

View File

@@ -405,13 +405,9 @@ def create_py_executor_instance(
lora_config.lora_target_modules,
lora_config.trtllm_modules_to_hf_modules)
- if mapping.has_pp():
-     num_micro_batches = mapping.pp_size
- else:
-     num_micro_batches = 1 if pytorch_backend_config.disable_overlap_scheduler else 2
+ max_num_sequences = executor_config.max_batch_size * mapping.pp_size
- resources["seq_slot_manager"] = SeqSlotManager(
-     executor_config.max_batch_size * num_micro_batches)
+ resources["seq_slot_manager"] = SeqSlotManager(max_num_sequences)
resource_manager = ResourceManager(resources)
@@ -422,10 +418,11 @@
last=True)
capacity_scheduler = BindCapacityScheduler(
- executor_config.max_batch_size,
+ max_num_sequences,
kv_cache_manager.impl if kv_cache_manager is not None else None,
executor_config.scheduler_config.capacity_scheduler_policy,
- num_micro_batches=num_micro_batches)
+ two_step_lookahead=mapping.has_pp()
+     or not pytorch_backend_config.disable_overlap_scheduler)
mb_scheduler = BindMicroBatchScheduler(executor_config.max_batch_size,
executor_config.max_num_tokens,
ctx_chunk_config)
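
The PyTorch executor path now mirrors the C++ computation: one pool of max_batch_size sequence slots per pipeline stage, with the lookahead flag derived from pipeline parallelism or the overlap scheduler. A worked example with hypothetical values:

    max_batch_size, pp_size = 32, 2
    disable_overlap_scheduler = False

    max_num_sequences = max_batch_size * pp_size                       # 64 slots
    two_step_lookahead = pp_size > 1 or not disable_overlap_scheduler  # True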

View File

@@ -77,16 +77,16 @@ class BindCapacityScheduler(CapacityScheduler):
kv_cache_manager,
scheduler_policy: tb_executor.CapacitySchedulerPolicy = tb_executor.
CapacitySchedulerPolicy.GUARANTEED_NO_EVICT,
- num_micro_batches: int = 1,
+ two_step_lookahead: bool = False,
):
super(BindCapacityScheduler, self).__init__()
self.kv_cache_manager = kv_cache_manager
self.impl = tb_internal.algorithms.CapacityScheduler(
- max_num_requests=max_num_requests * num_micro_batches,
+ max_num_requests=max_num_requests,
capacity_scheduler_policy=scheduler_policy,
has_kv_cache_manager=kv_cache_manager is not None,
- many_micro_batches=num_micro_batches > 1,
+ two_step_lookahead=two_step_lookahead,
no_schedule_until_state=LlmRequestState.CONTEXT_INIT,
no_schedule_after_state=LlmRequestState.GENERATION_COMPLETE)
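
Because the wrapper no longer multiplies max_num_requests by a micro-batch count, callers now pass the full sequence budget themselves. A hypothetical call site under these assumptions (the budget and policy values are illustrative; BindCapacityScheduler and tb_executor are the names from the hunk above, assumed to be in scope):

    scheduler = BindCapacityScheduler(
        max_num_requests=64,  # e.g. max_batch_size * pp_size, computed by the caller
        kv_cache_manager=kv_cache_manager.impl if kv_cache_manager is not None else None,
        scheduler_policy=tb_executor.CapacitySchedulerPolicy.MAX_UTILIZATION,
        two_step_lookahead=True)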

View File

@@ -191,8 +191,7 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
@skip_pre_hopper
def test_fp8_llm_sampler(self):
model_path = f"{llm_models_root()}/llama-3.1-model/Llama-3.1-8B-Instruct-FP8"
- pytorch_config = dict(enable_trtllm_sampler=True)
- llm = LLM(model_path, **pytorch_config)
+ llm = LLM(model_path, enable_trtllm_sampler=True, max_batch_size=256)
assert llm.args.quant_config.quant_algo == QuantAlgo.FP8
sampling_params = SamplingParams(