[https://nvbugs/5829097][fix] Disaggregated serving: Only send finished context requests to the KV cache transceiver (#11354)

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>
Robin Kobus 2026-02-09 10:11:45 +01:00 committed by GitHub
parent ab73f6ebc6
commit 31db399042
3 changed files with 48 additions and 8 deletions


@@ -1353,7 +1353,7 @@ class PyExecutor:
             if self.kv_cache_transceiver:
                 self._send_kv_async(
-                    previous_batch.scheduled_ctx_reqs)
+                    previous_batch.finished_ctx_reqs)
             self._handle_canceled_requests()
             self._handle_logits_communication(
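The one-line change above is the actual fix: `_send_kv_async` now receives only the context requests that have finished their prefill phase, rather than every context request that was scheduled in the previous batch. When prefill is chunked (or spread across pipeline micro-batches, as in the new ctx_pp tests below), a scheduled context request can still have prompt tokens left to process, so its KV cache is not yet complete and should not be handed to the transceiver. A minimal sketch of that distinction, using hypothetical `ContextRequest` / `select_finished_ctx_reqs` names rather than the real PyExecutor types:

from dataclasses import dataclass
from typing import List


@dataclass
class ContextRequest:
    """Hypothetical stand-in for a context (prefill) request tracked by the executor."""
    request_id: int
    prompt_len: int
    processed_tokens: int = 0  # prompt tokens prefilled so far

    @property
    def is_context_finished(self) -> bool:
        # A request only counts as "finished context" once its whole prompt
        # has been processed, i.e. its KV cache is complete.
        return self.processed_tokens >= self.prompt_len


def select_finished_ctx_reqs(scheduled_ctx_reqs: List[ContextRequest]) -> List[ContextRequest]:
    """Keep only requests whose KV cache is complete and therefore safe to transfer."""
    return [req for req in scheduled_ctx_reqs if req.is_context_finished]


# With a 256-token prefill chunk, request 1 still has prompt left to process,
# so only request 0 should reach the KV cache transceiver this iteration.
scheduled = [
    ContextRequest(request_id=0, prompt_len=200, processed_tokens=200),
    ContextRequest(request_id=1, prompt_len=500, processed_tokens=256),
]
assert [r.request_id for r in select_finished_ctx_reqs(scheduled)] == [0]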


@@ -453,6 +453,7 @@ def launch_disaggregated_llm(
 def run_parallel_test(model_name: str,
                       model_path: str,
+                      *,
                       ctx_pp: int,
                       ctx_tp: int,
                       gen_pp: int,
@@ -771,8 +772,15 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
     def test_tp_pp_symmetric(self, tp, pp, testset):
         if tp * pp * 2 > get_device_count():
             pytest.skip(f"Not enough devices for tp={tp}*pp={pp} test")
-        return run_parallel_test(self.MODEL_NAME, self.MODEL_PATH, pp, tp, pp,
-                                 tp, 1, 1, [get_accuracy_task(testset)])
+        return run_parallel_test(self.MODEL_NAME,
+                                 self.MODEL_PATH,
+                                 ctx_pp=pp,
+                                 ctx_tp=tp,
+                                 gen_pp=pp,
+                                 gen_tp=tp,
+                                 ctx_instances=1,
+                                 gen_instances=1,
+                                 test_sets=[get_accuracy_task(testset)])
 
     @parametrize_with_ids("ctx_pp", [2, 4])
     @parametrize_with_ids("gen_tp", [1, 2])
@@ -781,13 +789,27 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
         if ctx_pp + gen_tp > get_device_count():
             pytest.skip(
                 f"Not enough devices for ctx_pp={ctx_pp}+gen_tp={gen_tp} test")
-        return run_parallel_test(self.MODEL_NAME, self.MODEL_PATH, ctx_pp, 1, 1,
-                                 gen_tp, 1, 1, [get_accuracy_task(testset)])
+        return run_parallel_test(self.MODEL_NAME,
+                                 self.MODEL_PATH,
+                                 ctx_pp=ctx_pp,
+                                 ctx_tp=1,
+                                 gen_pp=1,
+                                 gen_tp=gen_tp,
+                                 ctx_instances=1,
+                                 gen_instances=1,
+                                 test_sets=[get_accuracy_task(testset)])
 
     @pytest.mark.parametrize("testset", ["GSM8K", "MMLU"])
     def test_multi_instance(self, testset):
-        return run_parallel_test(self.MODEL_NAME, self.MODEL_PATH, 1, 1, 1, 1,
-                                 2, 2, [get_accuracy_task(testset)])
+        return run_parallel_test(self.MODEL_NAME,
+                                 self.MODEL_PATH,
+                                 ctx_pp=1,
+                                 ctx_tp=1,
+                                 gen_pp=1,
+                                 gen_tp=1,
+                                 ctx_instances=2,
+                                 gen_instances=2,
+                                 test_sets=[get_accuracy_task(testset)])
 
 
 class TestLlama4ScoutInstruct(LlmapiAccuracyTestHarness):
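For readers of the test diff: `run_parallel_test` is now keyword-only after `model_path` (the `*,` added above), so the refactored call sites spell out which value is context versus generation parallelism instead of relying on positional order. A rough reconstruction of the full signature; the parameters after `gen_pp` are inferred from the call sites rather than copied from the file, and the comment describes the assumed behavior:

def run_parallel_test(model_name: str,
                      model_path: str,
                      *,
                      ctx_pp: int,
                      ctx_tp: int,
                      gen_pp: int,
                      gen_tp: int,
                      ctx_instances: int,
                      gen_instances: int,
                      test_sets: list):
    # Assumed behavior: launch ctx_instances context servers (ctx_pp x ctx_tp each)
    # and gen_instances generation servers (gen_pp x gen_tp each), then run the
    # given accuracy test sets against the disaggregated deployment.
    ...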
@@ -1313,10 +1335,16 @@ class TestQwen3_8B(LlmapiAccuracyTestHarness):
                               self.MODEL_PATH) as llm:
             run_accuracy_test(llm, self.MODEL_NAME, ["MMLU", "GSM8K"])
 
-    def test_chunked_prefill(self):
+    def _test_chunked_prefill_helper(self, *, ctx_pp: int):
         # bs=1 will stabilize the result, but the test will be much slower
         max_batch_size = 32
+        kv_cache_config = {
+            "enable_block_reuse": True if ctx_pp == 1 else False,
+        }
         ctx_server_config = {
+            "pipeline_parallel_size": ctx_pp,
             "disable_overlap_scheduler": True,
             "cuda_graph_config": None,
             "cache_transceiver_config": {
@@ -1325,6 +1353,7 @@ class TestQwen3_8B(LlmapiAccuracyTestHarness):
             "enable_chunked_prefill": True,
             "max_num_tokens": 256,
             "max_batch_size": max_batch_size,
+            "kv_cache_config": kv_cache_config,
         }
         gen_server_config = {
             "cuda_graph_config": None,
@@ -1351,6 +1380,16 @@ class TestQwen3_8B(LlmapiAccuracyTestHarness):
                               self.MODEL_PATH) as llm:
             run_accuracy_test(llm, self.MODEL_NAME, ["MMLU", "GSM8K"])
 
+    @skip_pre_hopper
+    @pytest.mark.skip_less_device(2)
+    def test_chunked_prefill(self):
+        self._test_chunked_prefill_helper(ctx_pp=1)
+
+    @skip_pre_hopper
+    @pytest.mark.skip_less_device(4)
+    def test_chunked_prefill_ctx_pp2(self):
+        self._test_chunked_prefill_helper(ctx_pp=2)
+
     @skip_pre_blackwell
     @pytest.mark.timeout(DEFAULT_TEST_TIMEOUT)
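Taken together, the two ctx-server hunks above build a context-server config roughly like the sketch below. Keys not shown in this diff (e.g. the contents of cache_transceiver_config) are left as placeholders, and the kv_cache_config dict is inlined here for readability:

ctx_pp = 2  # as exercised by the new test_chunked_prefill_ctx_pp2
max_batch_size = 32

ctx_server_config = {
    "pipeline_parallel_size": ctx_pp,
    "disable_overlap_scheduler": True,
    "cuda_graph_config": None,
    "cache_transceiver_config": {
        # backend/settings not shown in this diff; placeholder only
    },
    "enable_chunked_prefill": True,
    "max_num_tokens": 256,
    "max_batch_size": max_batch_size,
    "kv_cache_config": {
        # block reuse is only kept when the context server is not pipeline-parallel
        "enable_block_reuse": True if ctx_pp == 1 else False,
    },
}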


@@ -127,6 +127,7 @@ l0_dgx_h100:
   - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=2-ctx_pp=2]
   - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_multi_instance[GSM8K]
   - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_multi_instance[MMLU]
+  - accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_chunked_prefill_ctx_pp2
   - accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_bf16_4gpu[tp4ep4_cudagraph_overlap]
   - disaggregated/test_auto_scaling.py::test_service_discovery[etcd-round_robin]
   - disaggregated/test_auto_scaling.py::test_worker_restart[etcd-load_balancing]