diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index c8267d5745..1f03b86565 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -1353,7 +1353,7 @@ class PyExecutor: if self.kv_cache_transceiver: self._send_kv_async( - previous_batch.scheduled_ctx_reqs) + previous_batch.finished_ctx_reqs) self._handle_canceled_requests() self._handle_logits_communication( diff --git a/tests/integration/defs/accuracy/test_disaggregated_serving.py b/tests/integration/defs/accuracy/test_disaggregated_serving.py index d7be3c6272..f7719da646 100644 --- a/tests/integration/defs/accuracy/test_disaggregated_serving.py +++ b/tests/integration/defs/accuracy/test_disaggregated_serving.py @@ -453,6 +453,7 @@ def launch_disaggregated_llm( def run_parallel_test(model_name: str, model_path: str, + *, ctx_pp: int, ctx_tp: int, gen_pp: int, @@ -771,8 +772,15 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness): def test_tp_pp_symmetric(self, tp, pp, testset): if tp * pp * 2 > get_device_count(): pytest.skip(f"Not enough devices for tp={tp}*pp={pp} test") - return run_parallel_test(self.MODEL_NAME, self.MODEL_PATH, pp, tp, pp, - tp, 1, 1, [get_accuracy_task(testset)]) + return run_parallel_test(self.MODEL_NAME, + self.MODEL_PATH, + ctx_pp=pp, + ctx_tp=tp, + gen_pp=pp, + gen_tp=tp, + ctx_instances=1, + gen_instances=1, + test_sets=[get_accuracy_task(testset)]) @parametrize_with_ids("ctx_pp", [2, 4]) @parametrize_with_ids("gen_tp", [1, 2]) @@ -781,13 +789,27 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness): if ctx_pp + gen_tp > get_device_count(): pytest.skip( f"Not enough devices for ctx_pp={ctx_pp}+gen_tp={gen_tp} test") - return run_parallel_test(self.MODEL_NAME, self.MODEL_PATH, ctx_pp, 1, 1, - gen_tp, 1, 1, [get_accuracy_task(testset)]) + return run_parallel_test(self.MODEL_NAME, + self.MODEL_PATH, + ctx_pp=ctx_pp, + ctx_tp=1, + gen_pp=1, + gen_tp=gen_tp, + ctx_instances=1, + gen_instances=1, + test_sets=[get_accuracy_task(testset)]) @pytest.mark.parametrize("testset", ["GSM8K", "MMLU"]) def test_multi_instance(self, testset): - return run_parallel_test(self.MODEL_NAME, self.MODEL_PATH, 1, 1, 1, 1, - 2, 2, [get_accuracy_task(testset)]) + return run_parallel_test(self.MODEL_NAME, + self.MODEL_PATH, + ctx_pp=1, + ctx_tp=1, + gen_pp=1, + gen_tp=1, + ctx_instances=2, + gen_instances=2, + test_sets=[get_accuracy_task(testset)]) class TestLlama4ScoutInstruct(LlmapiAccuracyTestHarness): @@ -1313,10 +1335,16 @@ class TestQwen3_8B(LlmapiAccuracyTestHarness): self.MODEL_PATH) as llm: run_accuracy_test(llm, self.MODEL_NAME, ["MMLU", "GSM8K"]) - def test_chunked_prefill(self): + def _test_chunked_prefill_helper(self, *, ctx_pp: int): # bs=1 will stabilize the result, but the test will be much slower max_batch_size = 32 + + kv_cache_config = { + "enable_block_reuse": True if ctx_pp == 1 else False, + } + ctx_server_config = { + "pipeline_parallel_size": ctx_pp, "disable_overlap_scheduler": True, "cuda_graph_config": None, "cache_transceiver_config": { @@ -1325,6 +1353,7 @@ class TestQwen3_8B(LlmapiAccuracyTestHarness): "enable_chunked_prefill": True, "max_num_tokens": 256, "max_batch_size": max_batch_size, + "kv_cache_config": kv_cache_config, } gen_server_config = { "cuda_graph_config": None, @@ -1351,6 +1380,16 @@ class TestQwen3_8B(LlmapiAccuracyTestHarness): self.MODEL_PATH) as llm: run_accuracy_test(llm, self.MODEL_NAME, ["MMLU", "GSM8K"]) + @skip_pre_hopper + @pytest.mark.skip_less_device(2) + def test_chunked_prefill(self): + self._test_chunked_prefill_helper(ctx_pp=1) + + @skip_pre_hopper + @pytest.mark.skip_less_device(4) + def test_chunked_prefill_ctx_pp2(self): + self._test_chunked_prefill_helper(ctx_pp=2) + @skip_pre_blackwell @pytest.mark.timeout(DEFAULT_TEST_TIMEOUT) diff --git a/tests/integration/test_lists/test-db/l0_dgx_h100.yml b/tests/integration/test_lists/test-db/l0_dgx_h100.yml index c308bad8d7..398342c0a3 100644 --- a/tests/integration/test_lists/test-db/l0_dgx_h100.yml +++ b/tests/integration/test_lists/test-db/l0_dgx_h100.yml @@ -127,6 +127,7 @@ l0_dgx_h100: - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=2-ctx_pp=2] - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_multi_instance[GSM8K] - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_multi_instance[MMLU] + - accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_chunked_prefill_ctx_pp2 - accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_bf16_4gpu[tp4ep4_cudagraph_overlap] - disaggregated/test_auto_scaling.py::test_service_discovery[etcd-round_robin] - disaggregated/test_auto_scaling.py::test_worker_restart[etcd-load_balancing]