Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-02-16 15:55:08 +08:00)
[https://nvbugs/5829097][fix] Disaggregated serving: Only send finished context requests to the KV cache transceiver (#11354)
Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>
This commit is contained in:
parent ab73f6ebc6
commit 31db399042
@@ -1353,7 +1353,7 @@ class PyExecutor:
         if self.kv_cache_transceiver:
             self._send_kv_async(
-                previous_batch.scheduled_ctx_reqs)
+                previous_batch.finished_ctx_reqs)
         self._handle_canceled_requests()

         self._handle_logits_communication(
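The one-line change above is the core of the fix: a context request can stay in scheduled_ctx_reqs across several iterations (with chunked prefill, for instance) before its whole prompt has been processed, so handing the scheduled list to the KV cache transceiver could start shipping KV state for requests whose prefill is not yet complete. A minimal, self-contained sketch of the invariant, with hypothetical names rather than the real PyExecutor types:

from dataclasses import dataclass


@dataclass
class CtxRequest:
    # Stand-in for a scheduled context-phase request.
    request_id: int
    prompt_len: int
    processed_tokens: int = 0

    @property
    def is_finished(self) -> bool:
        # Prefill is done only once every prompt token has been processed.
        return self.processed_tokens >= self.prompt_len


def finished_ctx_reqs(scheduled):
    # Only requests with a fully populated KV cache are safe to hand to
    # the transceiver for transfer to a generation worker.
    return [r for r in scheduled if r.is_finished]


batch = [CtxRequest(0, prompt_len=1024, processed_tokens=1024),
         CtxRequest(1, prompt_len=4096, processed_tokens=256)]
assert [r.request_id for r in finished_ctx_reqs(batch)] == [0]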
@@ -453,6 +453,7 @@ def launch_disaggregated_llm(

 def run_parallel_test(model_name: str,
                       model_path: str,
+                      *,
                       ctx_pp: int,
                       ctx_tp: int,
                       gen_pp: int,
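The added bare *, makes everything after model_path keyword-only, which is what permits (and enforces) the call-site rewrites in the next two hunks. A toy illustration of the failure mode this rules out; the signature below is illustrative, not the real run_parallel_test:

def run_test(model: str, *, ctx_pp: int, ctx_tp: int, gen_pp: int, gen_tp: int):
    return (ctx_pp, ctx_tp, gen_pp, gen_tp)


try:
    # Look-alike ints passed positionally can silently transpose
    # parallelism settings; keyword-only parameters make this a TypeError.
    run_test("llama", 2, 1, 1, 2)
except TypeError as err:
    print(err)  # run_test() takes 1 positional argument but 5 were given

print(run_test("llama", ctx_pp=2, ctx_tp=1, gen_pp=1, gen_tp=2))  # (2, 1, 1, 2)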
@@ -771,8 +772,15 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
     def test_tp_pp_symmetric(self, tp, pp, testset):
         if tp * pp * 2 > get_device_count():
             pytest.skip(f"Not enough devices for tp={tp}*pp={pp} test")
-        return run_parallel_test(self.MODEL_NAME, self.MODEL_PATH, pp, tp, pp,
-                                 tp, 1, 1, [get_accuracy_task(testset)])
+        return run_parallel_test(self.MODEL_NAME,
+                                 self.MODEL_PATH,
+                                 ctx_pp=pp,
+                                 ctx_tp=tp,
+                                 gen_pp=pp,
+                                 gen_tp=tp,
+                                 ctx_instances=1,
+                                 gen_instances=1,
+                                 test_sets=[get_accuracy_task(testset)])

     @parametrize_with_ids("ctx_pp", [2, 4])
     @parametrize_with_ids("gen_tp", [1, 2])
@@ -781,13 +789,27 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
         if ctx_pp + gen_tp > get_device_count():
             pytest.skip(
                 f"Not enough devices for ctx_pp={ctx_pp}+gen_tp={gen_tp} test")
-        return run_parallel_test(self.MODEL_NAME, self.MODEL_PATH, ctx_pp, 1, 1,
-                                 gen_tp, 1, 1, [get_accuracy_task(testset)])
+        return run_parallel_test(self.MODEL_NAME,
+                                 self.MODEL_PATH,
+                                 ctx_pp=ctx_pp,
+                                 ctx_tp=1,
+                                 gen_pp=1,
+                                 gen_tp=gen_tp,
+                                 ctx_instances=1,
+                                 gen_instances=1,
+                                 test_sets=[get_accuracy_task(testset)])

     @pytest.mark.parametrize("testset", ["GSM8K", "MMLU"])
     def test_multi_instance(self, testset):
-        return run_parallel_test(self.MODEL_NAME, self.MODEL_PATH, 1, 1, 1, 1,
-                                 2, 2, [get_accuracy_task(testset)])
+        return run_parallel_test(self.MODEL_NAME,
+                                 self.MODEL_PATH,
+                                 ctx_pp=1,
+                                 ctx_tp=1,
+                                 gen_pp=1,
+                                 gen_tp=1,
+                                 ctx_instances=2,
+                                 gen_instances=2,
+                                 test_sets=[get_accuracy_task(testset)])


 class TestLlama4ScoutInstruct(LlmapiAccuracyTestHarness):
@@ -1313,10 +1335,16 @@ class TestQwen3_8B(LlmapiAccuracyTestHarness):
                               self.MODEL_PATH) as llm:
             run_accuracy_test(llm, self.MODEL_NAME, ["MMLU", "GSM8K"])

-    def test_chunked_prefill(self):
+    def _test_chunked_prefill_helper(self, *, ctx_pp: int):
         # bs=1 will stabilize the result, but the test will be much slower
         max_batch_size = 32
+
+        kv_cache_config = {
+            "enable_block_reuse": True if ctx_pp == 1 else False,
+        }
+
         ctx_server_config = {
+            "pipeline_parallel_size": ctx_pp,
             "disable_overlap_scheduler": True,
             "cuda_graph_config": None,
             "cache_transceiver_config": {
@@ -1325,6 +1353,7 @@ class TestQwen3_8B(LlmapiAccuracyTestHarness):
             "enable_chunked_prefill": True,
             "max_num_tokens": 256,
             "max_batch_size": max_batch_size,
+            "kv_cache_config": kv_cache_config,
         }
         gen_server_config = {
             "cuda_graph_config": None,
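Taken together, the two hunks above make the helper assemble a context-server config like the one sketched below. This is a hedged reconstruction: build_ctx_server_config is an illustrative name only, and the cache_transceiver_config block is left out because the diff truncates it.

def build_ctx_server_config(ctx_pp: int, max_batch_size: int = 32) -> dict:
    return {
        "pipeline_parallel_size": ctx_pp,
        "disable_overlap_scheduler": True,
        "cuda_graph_config": None,
        "enable_chunked_prefill": True,
        "max_num_tokens": 256,
        "max_batch_size": max_batch_size,
        # Block reuse stays enabled only for the single-stage context
        # server; "True if ctx_pp == 1 else False" reduces to this test.
        "kv_cache_config": {"enable_block_reuse": ctx_pp == 1},
    }


assert build_ctx_server_config(2)["kv_cache_config"]["enable_block_reuse"] is False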
@@ -1351,6 +1380,16 @@ class TestQwen3_8B(LlmapiAccuracyTestHarness):
                               self.MODEL_PATH) as llm:
             run_accuracy_test(llm, self.MODEL_NAME, ["MMLU", "GSM8K"])

+    @skip_pre_hopper
+    @pytest.mark.skip_less_device(2)
+    def test_chunked_prefill(self):
+        self._test_chunked_prefill_helper(ctx_pp=1)
+
+    @skip_pre_hopper
+    @pytest.mark.skip_less_device(4)
+    def test_chunked_prefill_ctx_pp2(self):
+        self._test_chunked_prefill_helper(ctx_pp=2)
+

 @skip_pre_blackwell
 @pytest.mark.timeout(DEFAULT_TEST_TIMEOUT)
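Both new tests call the shared helper and differ only in ctx_pp and in the device gate: the ctx_pp=1 variant requires 2 GPUs, the ctx_pp=2 variant 4. skip_less_device is TensorRT-LLM's own custom marker; the stand-in below only sketches the equivalent plain-pytest mechanics, assuming a CUDA-enabled torch install.

import pytest
import torch


def skip_less_device(n: int):
    # Hypothetical stand-in: skip when fewer than n GPUs are visible.
    return pytest.mark.skipif(torch.cuda.device_count() < n,
                              reason=f"needs at least {n} GPUs")


@skip_less_device(4)
def test_chunked_prefill_ctx_pp2_standin():
    pass  # the real test calls _test_chunked_prefill_helper(ctx_pp=2)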
@@ -127,6 +127,7 @@ l0_dgx_h100:
   - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=2-ctx_pp=2]
   - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_multi_instance[GSM8K]
   - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_multi_instance[MMLU]
+  - accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_chunked_prefill_ctx_pp2
   - accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_bf16_4gpu[tp4ep4_cudagraph_overlap]
   - disaggregated/test_auto_scaling.py::test_service_discovery[etcd-round_robin]
   - disaggregated/test_auto_scaling.py::test_worker_restart[etcd-load_balancing]
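With this entry in place the new case runs as part of the l0_dgx_h100 list; it can also be selected directly by its pytest node ID (path relative to the tests root the list uses), e.g. pytest "accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_chunked_prefill_ctx_pp2".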