[TRTLLM-10029][scheduler] Re-implement MicroBatchScheduler and CapacityScheduler in Python (#10273)

Signed-off-by: junq <22017000+QiJune@users.noreply.github.com>
Signed-off-by: Lanyu Liao <lancelly@users.noreply.github.com>
Signed-off-by: Lance Liao <108499334+lancelly@users.noreply.github.com>
Co-authored-by: junq <22017000+QiJune@users.noreply.github.com>
Co-authored-by: Lanyu Liao <lancelly@users.noreply.github.com>
Liao Lanyu 2026-01-20 10:31:13 +08:00 committed by GitHub
parent c6320d924d
commit dbb858ae0c
GPG Key ID: B5690EEEBB952194
8 changed files with 1167 additions and 76 deletions

View File

@@ -132,6 +132,7 @@ void initBindings(nb::module_& m)
.def_rw("max_new_tokens", &GenLlmReq::mMaxNewTokens)
.def_rw("sampling_config", &GenLlmReq::mSamplingConfig)
.def_prop_rw("state", &GenLlmReq::getState, &GenLlmReq::setState)
.def_prop_ro("state_value", [](GenLlmReq const& self) { return static_cast<int>(self.getState()); })
.def_prop_rw("streaming", &GenLlmReq::isStreaming, &GenLlmReq::setStreaming)
.def_rw("end_id", &GenLlmReq::mEndId)
.def_rw("pad_id", &GenLlmReq::mPadId)
@@ -175,6 +176,7 @@ void initBindings(nb::module_& m)
.def_prop_ro("is_disagg_generation_transmission_complete", &GenLlmReq::isDisaggGenerationTransmissionComplete)
.def_prop_ro(
"is_disagg_generation_transmission_in_progress", &GenLlmReq::isDisaggGenerationTransmissionInProgress)
.def_prop_ro("is_encoder_init_state", &GenLlmReq::isEncoderInitState)
.def_prop_ro("is_context_init_state", &GenLlmReq::isContextInitState)
.def_prop_ro("is_generation_in_progress_state", &GenLlmReq::isGenerationInProgressState)
.def_prop_ro("is_disagg_context_transmission_state", &GenLlmReq::isDisaggContextTransmissionState)
@@ -253,7 +255,20 @@ void initBindings(nb::module_& m)
})
.def_prop_rw("is_dummy_request", &GenLlmReq::isDummyRequest, &GenLlmReq::setIsDummyRequest)
.def_prop_ro("return_perf_metrics", &GenLlmReq::getReturnPerfMetrics)
.def_prop_rw("use_draft_model", &GenLlmReq::useDraftModel, &GenLlmReq::setUseDraftModel);
.def_prop_rw("use_draft_model", &GenLlmReq::useDraftModel, &GenLlmReq::setUseDraftModel)
.def("get_unique_tokens", nb::overload_cast<GenLlmReq::SizeType32>(&GenLlmReq::getUniqueTokens, nb::const_),
nb::arg("beam"))
.def("get_unique_tokens", nb::overload_cast<>(&GenLlmReq::getUniqueTokens, nb::const_))
.def("get_encoder_unique_tokens",
[](GenLlmReq& self)
{
auto const& encoderUniqueTokens = self.getEncoderUniqueTokens();
if (encoderUniqueTokens.has_value() && encoderUniqueTokens.value())
{
return std::optional<GenLlmReq::VecUniqueTokens>(*encoderUniqueTokens.value());
}
return std::optional<GenLlmReq::VecUniqueTokens>(std::nullopt);
});
nb::class_<tb::LlmRequest, GenLlmReq>(m, "LlmRequest", nb::dynamic_attr())
.def(

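The new GenericLlmRequest bindings above (state_value, is_encoder_init_state, and the get_unique_tokens / get_encoder_unique_tokens overloads) give Python-side code direct access to request state and token history. A minimal usage sketch, assuming `request` is a bound LlmRequest object (this helper is illustrative, not part of the PR):

    # Sketch only: reading the newly exposed request properties from Python.
    def summarize_request(request):
        info = {
            # state_value exposes the request state enum as a plain int
            "state_value": request.state_value,
            "is_context_init": request.is_context_init_state,
            "is_encoder_init": request.is_encoder_init_state,
            # unique tokens for beam 0 vs. the per-beam list over all beams
            "num_unique_tokens_beam0": len(request.get_unique_tokens(0)),
            "num_beams": len(request.get_unique_tokens()),
        }
        # get_encoder_unique_tokens returns None when no encoder tokens exist
        encoder_tokens = request.get_encoder_unique_tokens()
        info["num_encoder_unique_tokens"] = (len(encoder_tokens)
                                             if encoder_tokens is not None else 0)
        return info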
View File

@@ -481,6 +481,8 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m)
.def("store_context_blocks", &BaseKVCacheManager::storeContextBlocks, nb::call_guard<nb::gil_scoped_release>())
.def("store_blocks_for_reuse", &BaseKVCacheManager::storeBlocksForReuse,
nb::call_guard<nb::gil_scoped_release>())
.def("find_new_context_block", &BaseKVCacheManager::findNewContextBlock, nb::arg("unique_tokens"),
nb::arg("llm_request"), nb::call_guard<nb::gil_scoped_release>())
.def("get_cache_block_ids", &BaseKVCacheManager::getCacheBlockIds, nb::call_guard<nb::gil_scoped_release>())
.def("get_batch_cache_block_ids", &BaseKVCacheManager::getBatchCacheBlockIds,
nb::call_guard<nb::gil_scoped_release>())
@@ -524,7 +526,14 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m)
nb::arg("event_manager") = nullptr, nb::arg("enable_partial_reuse") = true,
nb::arg("copy_on_partial_reuse") = true, nb::arg("kv_connector_manager") = nullptr,
nb::arg("enable_indexer_k_cache") = false, nb::arg("indexer_k_cache_quant_block_size") = 128,
nb::arg("indexer_k_cache_index_head_dim") = 0, nb::call_guard<nb::gil_scoped_release>());
nb::arg("indexer_k_cache_index_head_dim") = 0, nb::call_guard<nb::gil_scoped_release>())
.def(
"scheduling_has_free_blocks",
[](tbk::KVCacheManager& self, SizeType32 numRequired, SizeType32 windowSize)
{ return self.getBlockManager().schedulingHasFreeBlocks(numRequired, windowSize); },
nb::arg("num_required"), nb::arg("window_size"), nb::call_guard<nb::gil_scoped_release>())
.def_prop_ro(
"is_variable_window", [](tbk::KVCacheManager& self) { return self.getBlockManager().isVariableWindow(); });
}
void tb::BasePeftCacheManagerBindings::initBindings(nb::module_& m)

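The KV-cache manager bindings gain scheduling-oriented hooks: find_new_context_block, scheduling_has_free_blocks, and is_variable_window. A minimal sketch of the kind of capacity check a Python scheduler could build on them, assuming `kv_cache_manager` is a bound KVCacheManager and the caller has already computed the required block counts per attention-window size (the helper below is illustrative, not part of the PR):

    # Sketch only: guaranteed-fit check against the C++ block manager.
    def fits_in_kv_cache(kv_cache_manager, blocks_needed_per_window):
        # is_variable_window indicates whether several attention-window sizes
        # (and thus several block pools) must be checked separately.
        if not kv_cache_manager.is_variable_window:
            (window_size, num_blocks), = blocks_needed_per_window.items()
            return kv_cache_manager.scheduling_has_free_blocks(num_blocks, window_size)
        return all(
            kv_cache_manager.scheduling_has_free_blocks(num_blocks, window_size)
            for window_size, num_blocks in blocks_needed_per_window.items())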
View File

@@ -136,6 +136,7 @@ void initBindings(pybind11::module_& m)
.def_readwrite("max_new_tokens", &GenLlmReq::mMaxNewTokens)
.def_readwrite("sampling_config", &GenLlmReq::mSamplingConfig)
.def_property("state", &GenLlmReq::getState, &GenLlmReq::setState)
.def_property_readonly("state_value", [](GenLlmReq const& self) { return static_cast<int>(self.getState()); })
.def_property("streaming", &GenLlmReq::isStreaming, &GenLlmReq::setStreaming)
.def_readwrite("end_id", &GenLlmReq::mEndId)
.def_readwrite("pad_id", &GenLlmReq::mPadId)
@@ -181,6 +182,7 @@ void initBindings(pybind11::module_& m)
"is_disagg_generation_transmission_complete", &GenLlmReq::isDisaggGenerationTransmissionComplete)
.def_property_readonly(
"is_disagg_generation_transmission_in_progress", &GenLlmReq::isDisaggGenerationTransmissionInProgress)
.def_property_readonly("is_encoder_init_state", &GenLlmReq::isEncoderInitState)
.def_property_readonly("is_context_init_state", &GenLlmReq::isContextInitState)
.def_property_readonly("is_generation_in_progress_state", &GenLlmReq::isGenerationInProgressState)
.def_property_readonly("is_disagg_context_transmission_state", &GenLlmReq::isDisaggContextTransmissionState)
@@ -259,7 +261,20 @@ void initBindings(pybind11::module_& m)
})
.def_property("is_dummy_request", &GenLlmReq::isDummyRequest, &GenLlmReq::setIsDummyRequest)
.def_property_readonly("return_perf_metrics", &GenLlmReq::getReturnPerfMetrics)
.def_property("use_draft_model", &GenLlmReq::useDraftModel, &GenLlmReq::setUseDraftModel);
.def_property("use_draft_model", &GenLlmReq::useDraftModel, &GenLlmReq::setUseDraftModel)
.def("get_unique_tokens", py::overload_cast<GenLlmReq::SizeType32>(&GenLlmReq::getUniqueTokens, py::const_),
py::arg("beam"))
.def("get_unique_tokens", py::overload_cast<>(&GenLlmReq::getUniqueTokens, py::const_))
.def("get_encoder_unique_tokens",
[](GenLlmReq& self)
{
auto const& encoderUniqueTokens = self.getEncoderUniqueTokens();
if (encoderUniqueTokens.has_value() && encoderUniqueTokens.value())
{
return std::optional<GenLlmReq::VecUniqueTokens>(*encoderUniqueTokens.value());
}
return std::optional<GenLlmReq::VecUniqueTokens>(std::nullopt);
});
py::classh<tb::LlmRequest, GenLlmReq>(m, "LlmRequest", pybind11::dynamic_attr())
.def(py::init<>(

View File

@@ -485,6 +485,8 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(py::module_& m)
.def("store_context_blocks", &BaseKVCacheManager::storeContextBlocks, py::call_guard<py::gil_scoped_release>())
.def("store_blocks_for_reuse", &BaseKVCacheManager::storeBlocksForReuse,
py::call_guard<py::gil_scoped_release>())
.def("find_new_context_block", &BaseKVCacheManager::findNewContextBlock, py::arg("unique_tokens"),
py::arg("llm_request"), py::call_guard<py::gil_scoped_release>())
.def("get_cache_block_ids", &BaseKVCacheManager::getCacheBlockIds, py::call_guard<py::gil_scoped_release>())
.def("get_batch_cache_block_ids", &BaseKVCacheManager::getBatchCacheBlockIds,
py::call_guard<py::gil_scoped_release>())
@@ -519,7 +521,14 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(py::module_& m)
py::arg("enable_partial_reuse") = true, py::arg("copy_on_partial_reuse") = true,
py::arg("kv_connector_manager") = nullptr, py::arg("enable_indexer_k_cache") = false,
py::arg("indexer_k_cache_quant_block_size") = 128, py::arg("indexer_k_cache_index_head_dim") = 0,
py::call_guard<py::gil_scoped_release>());
py::call_guard<py::gil_scoped_release>())
.def(
"scheduling_has_free_blocks",
[](tbk::KVCacheManager& self, SizeType32 numRequired, SizeType32 windowSize)
{ return self.getBlockManager().schedulingHasFreeBlocks(numRequired, windowSize); },
py::arg("num_required"), py::arg("window_size"), py::call_guard<py::gil_scoped_release>())
.def_property_readonly(
"is_variable_window", [](tbk::KVCacheManager& self) { return self.getBlockManager().isVariableWindow(); });
}
void tb::BasePeftCacheManagerBindings::initBindings(py::module_& m)

View File

@@ -39,7 +39,7 @@ from .resource_manager import (KVCacheManager, PeftCacheManager,
from .sampler import (EarlyStopSampler, EarlyStopWithMMResult, TorchSampler,
TRTLLMSampler)
from .scheduler import (BindCapacityScheduler, BindMicroBatchScheduler,
SimpleScheduler)
SimpleScheduler, SimpleUnifiedScheduler)
from .seq_slot_manager import SeqSlotManager
GB = 1 << 30
@@ -852,15 +852,29 @@ def create_py_executor_instance(
if scheduler_capacity == 1 and mapping.enable_attention_dp and kv_cache_manager:
scheduler_capacity += 1
capacity_scheduler = BindCapacityScheduler(
scheduler_capacity,
kv_cache_manager.impl if kv_cache_manager is not None else None,
peft_cache_manager.impl if peft_cache_manager is not None else None,
scheduler_config.capacity_scheduler_policy,
two_step_lookahead=mapping.has_pp())
mb_scheduler = BindMicroBatchScheduler(max_batch_size, max_num_tokens,
ctx_chunk_config)
scheduler = SimpleScheduler(capacity_scheduler, mb_scheduler)
use_python_scheduler = os.getenv("TLLM_USE_PYTHON_SCHEDULER", "0") == "1"
if use_python_scheduler:
scheduler = SimpleUnifiedScheduler(
max_batch_size=max_batch_size,
max_num_tokens=max_num_tokens,
kv_cache_manager=kv_cache_manager.impl
if kv_cache_manager is not None else None,
peft_cache_manager=peft_cache_manager.impl
if peft_cache_manager is not None else None,
scheduler_policy=scheduler_config.capacity_scheduler_policy,
ctx_chunk_config=ctx_chunk_config,
two_step_lookahead=mapping.has_pp(),
scheduler_capacity=scheduler_capacity)
else:
capacity_scheduler = BindCapacityScheduler(
scheduler_capacity,
kv_cache_manager.impl if kv_cache_manager is not None else None,
peft_cache_manager.impl if peft_cache_manager is not None else None,
scheduler_config.capacity_scheduler_policy,
two_step_lookahead=mapping.has_pp())
mb_scheduler = BindMicroBatchScheduler(max_batch_size, max_num_tokens,
ctx_chunk_config)
scheduler = SimpleScheduler(capacity_scheduler, mb_scheduler)
config = model_engine.model.model_config.pretrained_config
attention_type = AttentionTypeCpp.MLA if is_mla(

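Scheduler selection is now a runtime switch: when TLLM_USE_PYTHON_SCHEDULER=1, create_py_executor_instance builds the new SimpleUnifiedScheduler; otherwise it keeps the C++-backed BindCapacityScheduler + BindMicroBatchScheduler pair wrapped in SimpleScheduler. A sketch of flipping the toggle for a single-process run (multi-process/MPI runs forward the variable via env_overrides instead, as in the test below):

    # Sketch only: the env var must be set before the executor is created.
    import os
    os.environ["TLLM_USE_PYTHON_SCHEDULER"] = "1"
    # Building the LLM / PyExecutor after this point selects SimpleUnifiedScheduler.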
View File

@@ -2041,6 +2041,7 @@ class PyExecutor:
def _schedule(self):
scheduler_output = self.scheduler.schedule_request(
self.active_requests, self.inflight_req_ids)
scheduled_context_requests = scheduler_output.context_requests
if self.enable_attention_dp and self.attention_dp_enable_balance:
scheduled_context_requests = self._balance_adp_requests(
@@ -2060,6 +2061,7 @@
scheduled_requests.context_requests = scheduled_context_requests
scheduled_requests.generation_requests = scheduler_output.generation_requests
scheduled_requests.paused_requests = scheduler_output.paused_requests
return scheduled_requests, scheduler_output.fitting_disagg_gen_init_requests, scheduler_output.num_fitting_requests
@nvtx_range("_check_disagg_gen_transfer_status")

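_schedule now reads everything from a single scheduler_output object; the fields referenced in this hunk are context_requests, generation_requests, paused_requests, fitting_disagg_gen_init_requests, and num_fitting_requests. An illustrative sketch of that return shape (the real class definition lives in the scheduler module and is not shown in this excerpt):

    # Illustration only: the shape _schedule expects from schedule_request().
    from dataclasses import dataclass, field
    from typing import List

    @dataclass
    class SchedulerOutput:
        context_requests: List = field(default_factory=list)
        generation_requests: List = field(default_factory=list)
        paused_requests: List = field(default_factory=list)
        fitting_disagg_gen_init_requests: List = field(default_factory=list)
        num_fitting_requests: int = 0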
File diff suppressed because it is too large

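The suppressed diff above is presumably the Python scheduler module itself (the earlier import hunk adds SimpleUnifiedScheduler to .scheduler, and this file accounts for most of the +1167 lines). Since its contents are not shown, the following is only a rough sketch, under assumptions, of how a unified scheduler can combine the two passes named in the PR title; it is not the PR's implementation:

    # Rough sketch only: capacity pass (does the request still fit?) followed by a
    # micro-batch pass (respect max_batch_size / max_num_tokens).
    def toy_schedule(active_requests, fits_in_kv_cache, max_batch_size, max_num_tokens):
        context_requests, generation_requests, paused_requests = [], [], []
        token_budget = max_num_tokens
        for req in active_requests:
            if len(context_requests) + len(generation_requests) >= max_batch_size:
                break
            if not fits_in_kv_cache(req):              # capacity pass
                paused_requests.append(req)
                continue
            if req.is_generation_in_progress_state:    # micro-batch pass
                if token_budget >= 1:                  # one new token per decode step
                    generation_requests.append(req)
                    token_budget -= 1
            elif req.is_context_init_state:
                prompt_len = len(req.get_unique_tokens(0))
                if token_budget >= prompt_len:         # no context chunking in this toy
                    context_requests.append(req)
                    token_budget -= prompt_len
        return context_requests, generation_requests, paused_requests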
View File

@@ -21,7 +21,10 @@ def model_path():
return llm_models_root() / "llama-models-v2/TinyLlama-1.1B-Chat-v1.0"
def create_llm(model_dir, disable_overlap_scheduler, sampler_type):
def create_llm(model_dir,
disable_overlap_scheduler,
sampler_type,
env_overrides=None):
"""Create LLM with specific overlap scheduler setting"""
pytorch_config = dict(disable_overlap_scheduler=disable_overlap_scheduler,
sampler_type=sampler_type)
@@ -37,14 +40,23 @@ def create_llm(model_dir, disable_overlap_scheduler, sampler_type):
**pytorch_config,
kv_cache_config=trt_kv_cache_config,
max_num_tokens=
128 # Only one request longer than max_num_tokens is required to test chunked prefill
128, # Only one request longer than max_num_tokens is required to test chunked prefill
env_overrides=env_overrides,
)
@pytest.mark.parametrize("sampler_type", ["TorchSampler", "TRTLLMSampler"])
@pytest.mark.parametrize("use_python_scheduler", [False, True],
ids=["cpp_scheduler", "python_scheduler"])
@pytest.mark.high_cuda_memory
@pytest.mark.mpi_ray_parity
def test_overlap_scheduler_consistency(model_path, test_case, sampler_type):
def test_overlap_scheduler_consistency(model_path, test_case, sampler_type,
use_python_scheduler):
# Use env_overrides to pass env var to MPI subprocess
env_overrides = {
"TLLM_USE_PYTHON_SCHEDULER": "1"
} if use_python_scheduler else {}
# Test configuration
prompts = test_case["prompts"]
max_new_tokens = test_case["max_new_tokens"]
@@ -62,7 +74,8 @@ def test_overlap_scheduler_consistency(model_path, test_case, sampler_type):
# Test with overlap scheduler enabled
with create_llm(model_path,
disable_overlap_scheduler=False,
sampler_type=sampler_type) as llm:
sampler_type=sampler_type,
env_overrides=env_overrides) as llm:
outputs_with_overlap = llm.generate(prompts,
sampling_params=sampling_config,
use_tqdm=True)
@@ -73,7 +86,8 @@ def test_overlap_scheduler_consistency(model_path, test_case, sampler_type):
# Test with overlap scheduler disabled
with create_llm(model_path,
disable_overlap_scheduler=True,
sampler_type=sampler_type) as llm:
sampler_type=sampler_type,
env_overrides=env_overrides) as llm:
outputs_without_overlap = llm.generate(prompts,
sampling_params=sampling_config,
use_tqdm=True)
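The parametrize ids (cpp_scheduler / python_scheduler) make it easy to exercise one scheduler backend at a time; env_overrides is used because setting os.environ only in the test process would not reach MPI worker subprocesses. A sketch of selecting just the Python-scheduler variants locally (the -k expression simply matches the ids defined above; run it from the directory containing this test):

    # Sketch only: run just the python_scheduler variants of this test.
    import pytest
    raise SystemExit(pytest.main(
        ["-k", "test_overlap_scheduler_consistency and python_scheduler", "-v"]))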