From da967d0bd7218ae3e6a16fa12899dcf6d127eb56 Mon Sep 17 00:00:00 2001 From: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com> Date: Sat, 24 Jan 2026 11:29:37 +0800 Subject: [PATCH] [TRTLLM-10334] [feat] Support overlap scheduler for disagg ctx instances (#10755) Signed-off-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/py_executor.py | 12 +++------- tensorrt_llm/executor/base_worker.py | 16 ++++++------- tests/integration/defs/.test_durations | 4 ++-- .../accuracy/test_disaggregated_serving.py | 14 +++++++---- ...tp1cp2_deepseek_v3_lite_bf16_tllm_gen.yaml | 1 - ...ctxtp4_gentp4_deepseek_r1_v2_fp4_tllm.yaml | 2 -- ...g_config_deepseek_v3_lite_empty_batch.yaml | 1 - .../test_configs/disagg_config_overlap.yaml | 4 ++-- .../test_lists/qa/llm_function_core.txt | 4 ++-- .../qa/llm_function_core_sanity.txt | 4 ++-- .../test_lists/qa/llm_function_rtx6k.txt | 4 ++-- .../test_lists/test-db/l0_dgx_h100.yml | 24 ++++++++++++------- 12 files changed, 46 insertions(+), 44 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 39365a8714..eaf6e05098 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -104,7 +104,7 @@ class BatchState: iter_start_time: float = 0 iter_stats: IterationStats = None - ctx_transmission_reqs: list[LlmRequest] = None + all_requests: list[LlmRequest] = None @dataclasses.dataclass @@ -1804,6 +1804,7 @@ class PyExecutor: if self.previous_batch is not None and should_process_previous_batch: self._update_requests(self.previous_batch.sample_state) + self._send_kv_async(self.previous_batch.all_requests) if self.drafter is not None and self.use_spec_decode and should_process_previous_batch: # Cleanup previous draft resources used in the draft model @@ -1829,9 +1830,6 @@ class PyExecutor: self._update_request_states(scheduled_batch) - ctx_transmission_reqs = self._send_kv_async( - scheduled_batch.all_requests()) - if self.previous_batch is not None and should_process_previous_batch: self._process_previous_batch() else: @@ -1846,7 +1844,7 @@ class PyExecutor: sample_state=sample_state, iter_start_time=iter_start_time, iter_stats=iter_stats, - ctx_transmission_reqs=ctx_transmission_reqs) + all_requests=scheduled_batch.all_requests()) elif not can_queue_this_rank: # If the batch is empty on this rank, we need to clear the previous batch. self.previous_batch = None @@ -1949,10 +1947,6 @@ class PyExecutor: return result_tensors, num_accepted_tokens def _process_previous_batch(self): - if self.kv_cache_transceiver and self.previous_batch.ctx_transmission_reqs: - for req in self.previous_batch.ctx_transmission_reqs: - req.state = LlmRequestState.DISAGG_CONTEXT_TRANS_IN_PROGRESS - self._handle_canceled_requests() finished_requests = self._handle_responses() scheduled_requests = self.previous_batch.sample_state.scheduled_requests diff --git a/tensorrt_llm/executor/base_worker.py b/tensorrt_llm/executor/base_worker.py index 76fcb4f096..0b6a417793 100644 --- a/tensorrt_llm/executor/base_worker.py +++ b/tensorrt_llm/executor/base_worker.py @@ -448,15 +448,13 @@ class BaseWorker(GenerationExecutor): context_phase_params = request.disaggregated_params.get_context_phase_params( ) - if self._is_pytorch_backend: - if not self.llm_args.disable_overlap_scheduler: - is_disaggregated = self.engine.kv_cache_transceiver is not None - if is_disaggregated and ( - request_type - == tllm.RequestType.REQUEST_TYPE_CONTEXT_ONLY): - raise ValueError( - "Context only requests are not supported in pytorch backend when overlap is enabled." - ) + if self._is_pytorch_backend and not self.llm_args.disable_overlap_scheduler \ + and self.llm_args.kv_cache_config.enable_block_reuse \ + and self.engine.kv_cache_transceiver is not None \ + and request_type == tllm.RequestType.REQUEST_TYPE_CONTEXT_ONLY: + raise ValueError( + "Context only requests are not supported in pytorch backend when overlap is enabled with block reuse." + ) assert request.id is not None diff --git a/tests/integration/defs/.test_durations b/tests/integration/defs/.test_durations index d0abdb6eae..6da286cf0c 100644 --- a/tests/integration/defs/.test_durations +++ b/tests/integration/defs/.test_durations @@ -144,8 +144,8 @@ "accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_nixl_backend": 71.2399792142678, "accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[False]": 286.7775873204227537, "accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[True]": 286.6778334858827293, - "accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-False-False]": 781.7928658421151, - "accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-True-True]": 270.3750694899354, + "accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-False-False-True]": 781.7928658421151, + "accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-True-True-True]": 270.3750694899354, "accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[GSM8K-gen_tp=1-ctx_pp=2]": 195.4896494857967, "accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[GSM8K-gen_tp=1-ctx_pp=4]": 205.93911361903884, "accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[GSM8K-gen_tp=2-ctx_pp=2]": 188.56422709790058, diff --git a/tests/integration/defs/accuracy/test_disaggregated_serving.py b/tests/integration/defs/accuracy/test_disaggregated_serving.py index f05e327c9e..8f67aa9074 100644 --- a/tests/integration/defs/accuracy/test_disaggregated_serving.py +++ b/tests/integration/defs/accuracy/test_disaggregated_serving.py @@ -524,20 +524,26 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness): @skip_pre_hopper @pytest.mark.skip_less_device(2) - @pytest.mark.parametrize("disable_overlap_scheduler", [False, True]) + @pytest.mark.parametrize("ctx_disable_overlap_scheduler", [False, True]) + @pytest.mark.parametrize("gen_disable_overlap_scheduler", [False, True]) @pytest.mark.parametrize("ctx_enable_block_reuse", [True, False]) @pytest.mark.parametrize("gen_enable_block_reuse", [True, False]) - def test_auto_dtype(self, disable_overlap_scheduler, ctx_enable_block_reuse, + def test_auto_dtype(self, ctx_disable_overlap_scheduler, + gen_disable_overlap_scheduler, ctx_enable_block_reuse, gen_enable_block_reuse): + if ctx_enable_block_reuse and not ctx_disable_overlap_scheduler: + pytest.skip( + "Skip this test because overlap scheduler is not supported with block reuse for context server" + ) ctx_server_config = { - "disable_overlap_scheduler": True, + "disable_overlap_scheduler": ctx_disable_overlap_scheduler, "kv_cache_config": { "enable_block_reuse": ctx_enable_block_reuse } } ctx_server_config["cache_transceiver_config"] = {"backend": "DEFAULT"} gen_server_config = { - "disable_overlap_scheduler": disable_overlap_scheduler, + "disable_overlap_scheduler": gen_disable_overlap_scheduler, "kv_cache_config": { "enable_block_reuse": gen_enable_block_reuse } diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp1cp2_deepseek_v3_lite_bf16_tllm_gen.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp1cp2_deepseek_v3_lite_bf16_tllm_gen.yaml index 7cfb488871..f7e879bb4c 100644 --- a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp1cp2_deepseek_v3_lite_bf16_tllm_gen.yaml +++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp1cp2_deepseek_v3_lite_bf16_tllm_gen.yaml @@ -3,7 +3,6 @@ port: 8000 model: DeepSeek-V3-Lite/bf16 free_gpu_memory_fraction: 0.25 backend: "pytorch" -disable_overlap_scheduler: True cuda_graph_config: null context_servers: num_instances: 1 diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp4_gentp4_deepseek_r1_v2_fp4_tllm.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp4_gentp4_deepseek_r1_v2_fp4_tllm.yaml index 1110cc2f2a..1d1535d1ae 100644 --- a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp4_gentp4_deepseek_r1_v2_fp4_tllm.yaml +++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp4_gentp4_deepseek_r1_v2_fp4_tllm.yaml @@ -18,7 +18,6 @@ context_servers: enable_block_reuse: false free_gpu_memory_fraction: 0.80 dtype: fp8 - disable_overlap_scheduler: true moe_config: backend: TRTLLM cuda_graph_config: null @@ -44,7 +43,6 @@ generation_servers: enable_block_reuse: false free_gpu_memory_fraction: 0.80 dtype: fp8 - disable_overlap_scheduler: true moe_config: backend: TRTLLM cuda_graph_config: diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_deepseek_v3_lite_empty_batch.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_deepseek_v3_lite_empty_batch.yaml index 409a314ec4..920fa0f053 100644 --- a/tests/integration/defs/disaggregated/test_configs/disagg_config_deepseek_v3_lite_empty_batch.yaml +++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_deepseek_v3_lite_empty_batch.yaml @@ -16,7 +16,6 @@ context_servers: pipeline_parallel_size: 1 print_iter_log: true cuda_graph_config: null - disable_overlap_scheduler: true kv_cache_config: enable_block_reuse: false free_gpu_memory_fraction: 0.05 diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_overlap.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_overlap.yaml index 55990bbaa6..d51ffabf8a 100644 --- a/tests/integration/defs/disaggregated/test_configs/disagg_config_overlap.yaml +++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_overlap.yaml @@ -12,9 +12,9 @@ context_servers: tensor_parallel_size: 1 pipeline_parallel_size: 1 kv_cache_config: + enable_block_reuse: False free_gpu_memory_fraction: 0.2 enable_partial_reuse: False - disable_overlap_scheduler: True cache_transceiver_config: backend: DEFAULT urls: @@ -27,9 +27,9 @@ generation_servers: max_num_tokens: 4096 max_seq_len: 4096 kv_cache_config: + enable_block_reuse: False free_gpu_memory_fraction: 0.2 enable_partial_reuse: False - disable_overlap_scheduler: False cache_transceiver_config: backend: DEFAULT urls: diff --git a/tests/integration/test_lists/qa/llm_function_core.txt b/tests/integration/test_lists/qa/llm_function_core.txt index 077fe6a61a..2098f7b010 100644 --- a/tests/integration/test_lists/qa/llm_function_core.txt +++ b/tests/integration/test_lists/qa/llm_function_core.txt @@ -304,8 +304,8 @@ accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[F accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[True] accuracy/test_disaggregated_serving.py::TestGPTOSS::test_auto_dtype[True] accuracy/test_disaggregated_serving.py::TestGPTOSS::test_auto_dtype[False] -accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-False-False] -accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-True-True] +accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-False-False-True] +accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-True-True-True] accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ngram accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_eagle3[eagle3_one_model=True-overlap_scheduler=True] accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_eagle3[eagle3_one_model=False-overlap_scheduler=False] diff --git a/tests/integration/test_lists/qa/llm_function_core_sanity.txt b/tests/integration/test_lists/qa/llm_function_core_sanity.txt index a691fc1d5a..5b5c801d70 100644 --- a/tests/integration/test_lists/qa/llm_function_core_sanity.txt +++ b/tests/integration/test_lists/qa/llm_function_core_sanity.txt @@ -162,8 +162,8 @@ accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_nixl_backend accuracy/test_disaggregated_serving.py::TestDeepSeekV32Exp::test_auto_dtype[False] accuracy/test_disaggregated_serving.py::TestGPTOSS::test_auto_dtype[False] accuracy/test_disaggregated_serving.py::TestGPTOSS::test_auto_dtype[True] -accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-False-False] -accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-True-True] +accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-False-False-True] +accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-True-True-True] accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[GSM8K-gen_tp=1-ctx_pp=2] accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[GSM8K-gen_tp=1-ctx_pp=4] accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[GSM8K-gen_tp=2-ctx_pp=2] diff --git a/tests/integration/test_lists/qa/llm_function_rtx6k.txt b/tests/integration/test_lists/qa/llm_function_rtx6k.txt index 395f3f2a5e..c9b42399eb 100644 --- a/tests/integration/test_lists/qa/llm_function_rtx6k.txt +++ b/tests/integration/test_lists/qa/llm_function_rtx6k.txt @@ -176,8 +176,8 @@ accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[llguidance-mtp_nextn=2] accuracy/test_disaggregated_serving.py::TestGPTOSS::test_auto_dtype[True] accuracy/test_disaggregated_serving.py::TestGPTOSS::test_auto_dtype[False] -accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-False-False] -accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-True-True] +accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-False-False-True] +accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-True-True-True] accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ngram accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_eagle3[eagle3_one_model=True-overlap_scheduler=True] accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_eagle3[eagle3_one_model=False-overlap_scheduler=False] diff --git a/tests/integration/test_lists/test-db/l0_dgx_h100.yml b/tests/integration/test_lists/test-db/l0_dgx_h100.yml index aa823a0450..398a1827fb 100644 --- a/tests/integration/test_lists/test-db/l0_dgx_h100.yml +++ b/tests/integration/test_lists/test-db/l0_dgx_h100.yml @@ -33,14 +33,22 @@ l0_dgx_h100: - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ngram - accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[False] - accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[True] - - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-False-False] - - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-False-True] - - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-True-False] - - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-True-True] - - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-False-False] - - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-False-True] - - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-True-False] - - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-True-True] + - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-False-False-False] + - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-False-False-True] + - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-False-True-False] + - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-False-True-True] + - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-True-False-False] + - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-True-False-True] + - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-True-True-False] + - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-True-True-True] + - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-False-False-False] + - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-False-False-True] + - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-False-True-False] + - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-False-True-True] + - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-True-False-False] + - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-True-False-True] + - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-True-True-False] + - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-True-True-True] - unittest/llmapi/apps/test_disagg_serving_perf_metrics.py - disaggregated/test_disaggregated.py::test_disaggregated_cancel_large_context_requests[DeepSeek-V3-Lite-bf16] # llmapi