[TRTLLM-10334] [feat] Support overlap scheduler for disagg ctx instances (#10755)

Signed-off-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com>
This commit is contained in:
Kaiyu Xie 2026-01-24 11:29:37 +08:00 committed by GitHub
parent 58dc4bea9c
commit da967d0bd7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 46 additions and 44 deletions

View File

@ -104,7 +104,7 @@ class BatchState:
iter_start_time: float = 0
iter_stats: IterationStats = None
ctx_transmission_reqs: list[LlmRequest] = None
all_requests: list[LlmRequest] = None
@dataclasses.dataclass
@ -1804,6 +1804,7 @@ class PyExecutor:
if self.previous_batch is not None and should_process_previous_batch:
self._update_requests(self.previous_batch.sample_state)
self._send_kv_async(self.previous_batch.all_requests)
if self.drafter is not None and self.use_spec_decode and should_process_previous_batch:
# Cleanup previous draft resources used in the draft model
@ -1829,9 +1830,6 @@ class PyExecutor:
self._update_request_states(scheduled_batch)
ctx_transmission_reqs = self._send_kv_async(
scheduled_batch.all_requests())
if self.previous_batch is not None and should_process_previous_batch:
self._process_previous_batch()
else:
@ -1846,7 +1844,7 @@ class PyExecutor:
sample_state=sample_state,
iter_start_time=iter_start_time,
iter_stats=iter_stats,
ctx_transmission_reqs=ctx_transmission_reqs)
all_requests=scheduled_batch.all_requests())
elif not can_queue_this_rank:
# If the batch is empty on this rank, we need to clear the previous batch.
self.previous_batch = None
@ -1949,10 +1947,6 @@ class PyExecutor:
return result_tensors, num_accepted_tokens
def _process_previous_batch(self):
if self.kv_cache_transceiver and self.previous_batch.ctx_transmission_reqs:
for req in self.previous_batch.ctx_transmission_reqs:
req.state = LlmRequestState.DISAGG_CONTEXT_TRANS_IN_PROGRESS
self._handle_canceled_requests()
finished_requests = self._handle_responses()
scheduled_requests = self.previous_batch.sample_state.scheduled_requests

View File

@ -448,15 +448,13 @@ class BaseWorker(GenerationExecutor):
context_phase_params = request.disaggregated_params.get_context_phase_params(
)
if self._is_pytorch_backend:
if not self.llm_args.disable_overlap_scheduler:
is_disaggregated = self.engine.kv_cache_transceiver is not None
if is_disaggregated and (
request_type
== tllm.RequestType.REQUEST_TYPE_CONTEXT_ONLY):
raise ValueError(
"Context only requests are not supported in pytorch backend when overlap is enabled."
)
if self._is_pytorch_backend and not self.llm_args.disable_overlap_scheduler \
and self.llm_args.kv_cache_config.enable_block_reuse \
and self.engine.kv_cache_transceiver is not None \
and request_type == tllm.RequestType.REQUEST_TYPE_CONTEXT_ONLY:
raise ValueError(
"Context only requests are not supported in pytorch backend when overlap is enabled with block reuse."
)
assert request.id is not None

View File

@ -144,8 +144,8 @@
"accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_nixl_backend": 71.2399792142678,
"accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[False]": 286.7775873204227537,
"accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[True]": 286.6778334858827293,
"accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-False-False]": 781.7928658421151,
"accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-True-True]": 270.3750694899354,
"accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-False-False-True]": 781.7928658421151,
"accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-True-True-True]": 270.3750694899354,
"accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[GSM8K-gen_tp=1-ctx_pp=2]": 195.4896494857967,
"accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[GSM8K-gen_tp=1-ctx_pp=4]": 205.93911361903884,
"accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[GSM8K-gen_tp=2-ctx_pp=2]": 188.56422709790058,

View File

@ -524,20 +524,26 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
@skip_pre_hopper
@pytest.mark.skip_less_device(2)
@pytest.mark.parametrize("disable_overlap_scheduler", [False, True])
@pytest.mark.parametrize("ctx_disable_overlap_scheduler", [False, True])
@pytest.mark.parametrize("gen_disable_overlap_scheduler", [False, True])
@pytest.mark.parametrize("ctx_enable_block_reuse", [True, False])
@pytest.mark.parametrize("gen_enable_block_reuse", [True, False])
def test_auto_dtype(self, disable_overlap_scheduler, ctx_enable_block_reuse,
def test_auto_dtype(self, ctx_disable_overlap_scheduler,
gen_disable_overlap_scheduler, ctx_enable_block_reuse,
gen_enable_block_reuse):
if ctx_enable_block_reuse and not ctx_disable_overlap_scheduler:
pytest.skip(
"Skip this test because overlap scheduler is not supported with block reuse for context server"
)
ctx_server_config = {
"disable_overlap_scheduler": True,
"disable_overlap_scheduler": ctx_disable_overlap_scheduler,
"kv_cache_config": {
"enable_block_reuse": ctx_enable_block_reuse
}
}
ctx_server_config["cache_transceiver_config"] = {"backend": "DEFAULT"}
gen_server_config = {
"disable_overlap_scheduler": disable_overlap_scheduler,
"disable_overlap_scheduler": gen_disable_overlap_scheduler,
"kv_cache_config": {
"enable_block_reuse": gen_enable_block_reuse
}

View File

@ -3,7 +3,6 @@ port: 8000
model: DeepSeek-V3-Lite/bf16
free_gpu_memory_fraction: 0.25
backend: "pytorch"
disable_overlap_scheduler: True
cuda_graph_config: null
context_servers:
num_instances: 1

View File

@ -18,7 +18,6 @@ context_servers:
enable_block_reuse: false
free_gpu_memory_fraction: 0.80
dtype: fp8
disable_overlap_scheduler: true
moe_config:
backend: TRTLLM
cuda_graph_config: null
@ -44,7 +43,6 @@ generation_servers:
enable_block_reuse: false
free_gpu_memory_fraction: 0.80
dtype: fp8
disable_overlap_scheduler: true
moe_config:
backend: TRTLLM
cuda_graph_config:

View File

@ -16,7 +16,6 @@ context_servers:
pipeline_parallel_size: 1
print_iter_log: true
cuda_graph_config: null
disable_overlap_scheduler: true
kv_cache_config:
enable_block_reuse: false
free_gpu_memory_fraction: 0.05

View File

@ -12,9 +12,9 @@ context_servers:
tensor_parallel_size: 1
pipeline_parallel_size: 1
kv_cache_config:
enable_block_reuse: False
free_gpu_memory_fraction: 0.2
enable_partial_reuse: False
disable_overlap_scheduler: True
cache_transceiver_config:
backend: DEFAULT
urls:
@ -27,9 +27,9 @@ generation_servers:
max_num_tokens: 4096
max_seq_len: 4096
kv_cache_config:
enable_block_reuse: False
free_gpu_memory_fraction: 0.2
enable_partial_reuse: False
disable_overlap_scheduler: False
cache_transceiver_config:
backend: DEFAULT
urls:

View File

@ -304,8 +304,8 @@ accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[F
accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[True]
accuracy/test_disaggregated_serving.py::TestGPTOSS::test_auto_dtype[True]
accuracy/test_disaggregated_serving.py::TestGPTOSS::test_auto_dtype[False]
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-False-False]
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-True-True]
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-False-False-True]
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-True-True-True]
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ngram
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_eagle3[eagle3_one_model=True-overlap_scheduler=True]
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_eagle3[eagle3_one_model=False-overlap_scheduler=False]

View File

@ -162,8 +162,8 @@ accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_nixl_backend
accuracy/test_disaggregated_serving.py::TestDeepSeekV32Exp::test_auto_dtype[False]
accuracy/test_disaggregated_serving.py::TestGPTOSS::test_auto_dtype[False]
accuracy/test_disaggregated_serving.py::TestGPTOSS::test_auto_dtype[True]
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-False-False]
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-True-True]
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-False-False-True]
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-True-True-True]
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[GSM8K-gen_tp=1-ctx_pp=2]
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[GSM8K-gen_tp=1-ctx_pp=4]
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[GSM8K-gen_tp=2-ctx_pp=2]

View File

@ -176,8 +176,8 @@ accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[llguidance-mtp_nextn=2]
accuracy/test_disaggregated_serving.py::TestGPTOSS::test_auto_dtype[True]
accuracy/test_disaggregated_serving.py::TestGPTOSS::test_auto_dtype[False]
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-False-False]
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-True-True]
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-False-False-True]
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-True-True-True]
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ngram
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_eagle3[eagle3_one_model=True-overlap_scheduler=True]
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_eagle3[eagle3_one_model=False-overlap_scheduler=False]

View File

@ -33,14 +33,22 @@ l0_dgx_h100:
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ngram
- accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[False]
- accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[True]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-False-False]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-False-True]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-True-False]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-True-True]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-False-False]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-False-True]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-True-False]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-True-True]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-False-False-False]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-False-False-True]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-False-True-False]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-False-True-True]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-True-False-False]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-True-False-True]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-True-True-False]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-True-True-True]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-False-False-False]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-False-False-True]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-False-True-False]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-False-True-True]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-True-False-False]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-True-False-True]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-True-True-False]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-True-True-True]
- unittest/llmapi/apps/test_disagg_serving_perf_metrics.py
- disaggregated/test_disaggregated.py::test_disaggregated_cancel_large_context_requests[DeepSeek-V3-Lite-bf16]
# llmapi