From c68d916b6ffd70f390df438814f5289d9315c60d Mon Sep 17 00:00:00 2001
From: Patrice Castonguay <55748270+pcastonguay@users.noreply.github.com>
Date: Mon, 9 Feb 2026 14:39:02 -0500
Subject: [PATCH] [None][chore] Unit test for disagg gen cancellation (#11108)

Signed-off-by: Patrice Castonguay <55748270+pcastonguay@users.noreply.github.com>
---
 tests/unittest/llmapi/test_llm_pytorch.py | 229 +++++++++++++++++++++-
 1 file changed, 227 insertions(+), 2 deletions(-)

diff --git a/tests/unittest/llmapi/test_llm_pytorch.py b/tests/unittest/llmapi/test_llm_pytorch.py
index b7ffce9027..0e52f3000c 100644
--- a/tests/unittest/llmapi/test_llm_pytorch.py
+++ b/tests/unittest/llmapi/test_llm_pytorch.py
@@ -23,8 +23,8 @@ from .lora_test_utils import (
     create_mock_nemo_lora_checkpoint, compare_cuda_graph_lora_params_filler,
     CUDAGraphLoRATestParams, test_lora_with_and_without_cuda_graph)
 from .test_llm import (_test_llm_capture_request_error, get_model_path,
-                       global_kvcache_config, llama_model_path,
-                       llm_get_stats_async_test_harness,
+                       global_kvcache_config, global_kvcache_config_no_reuse,
+                       llama_model_path, llm_get_stats_async_test_harness,
                        llm_get_stats_test_harness,
                        llm_return_logprobs_test_harness, llm_test_harness,
                        prompts, run_llm_abort_request,
@@ -1313,3 +1313,228 @@ def test_llm_context_only_timed_out_kv_cache_exhausted(sender_future_timeout_ms,
     final_used_num_blocks = results[0]["kvCacheStats"]["usedNumBlocks"]
     assert final_used_num_blocks == 0
+
+
+@pytest.mark.threadleak(enabled=False)
+@pytest.mark.part0
+@skip_ray
+@pytest.mark.asyncio
+async def test_llm_disagg_gen_cancelled():
+    tp_size = 1
+    use_overlap = False
+    enable_iter_req_stats = False
+
+    llm_args_extra = {}
+
+    llm_args_extra.update(
+        dict(enable_iter_perf_stats=True,
+             enable_iter_req_stats=enable_iter_req_stats,
+             disable_overlap_scheduler=not use_overlap))
+
+    llm_ctx = LLM(model=llama_model_path,
+                  kv_cache_config=global_kvcache_config_no_reuse,
+                  tensor_parallel_size=tp_size,
+                  cache_transceiver_config=CacheTransceiverConfig(
+                      backend="UCX", kv_transfer_timeout_ms=1000),
+                  **llm_args_extra)
+
+    llm_gen = LLM(model=llama_model_path,
+                  kv_cache_config=global_kvcache_config_no_reuse,
+                  tensor_parallel_size=tp_size,
+                  cache_transceiver_config=CacheTransceiverConfig(
+                      backend="UCX", kv_transfer_timeout_ms=1000),
+                  **llm_args_extra)
+
+    try:
+        num_iterations = 10
+        prev_after_free_num_blocks = 0
+        for iter in range(num_iterations):
+
+            max_tokens = 1
+            sampling_params = SamplingParams(max_tokens=max_tokens)
+            disaggregated_params = DisaggregatedParams(
+                request_type="context_only")
+
+            prompt = [
+                "lorem ipsum dolor sit amet consectetur adipiscing elit sed do eiusmod tempor incididunt ut labore et dolore magna aliqua "
+                * 10
+            ]
+            # Send context-only request
+            ctx_outputs = []
+            for output in llm_ctx.generate(
+                    prompt,
+                    sampling_params=sampling_params,
+                    disaggregated_params=disaggregated_params):
+                ctx_outputs.append(output)
+
+            assert len(ctx_outputs) == 1
+
+            max_tokens = 10000
+            sampling_params = SamplingParams(max_tokens=max_tokens,
+                                             ignore_eos=True)
+            disaggregated_params = ctx_outputs[0].disaggregated_params
+            disaggregated_params.request_type = "generation_only"
+
+            # Send gen-only request
+            gen_output = llm_gen.generate_async(
+                prompt[0],
+                sampling_params=sampling_params,
+                disaggregated_params=disaggregated_params)
+
+            # Sleep a little (between 0.2 and 0.7 seconds) so that some tokens are generated
+            sleep_time = random.uniform(0.2, 0.7)
+            time.sleep(sleep_time)
+            # Abort the generation request
+            gen_output.abort()
+
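+            # Await the final result of the aborted request; it should
+            # report a "cancelled" finish reason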
+            result = await gen_output.aresult()
+            num_output_tokens = len(result.outputs[0].token_ids)
+            print(f"num output tokens: {num_output_tokens}")
+            assert result.outputs[0].finish_reason == "cancelled"
+
+            # Now check that the number of free/used blocks is as expected
+            time.sleep(1.)
+            max_retries = 10
+            for _ in range(max_retries):
+                results = llm_gen.get_stats(2)
+                print("len(results):", len(results))
+                if len(results) == num_output_tokens - (1 if iter == 0 else 0):
+                    break
+                time.sleep(1)
+            else:
+                pytest.fail(
+                    f"Failed to get stats with the expected length after {max_retries} retries"
+                )
+
+            after_used_num_blocks = results[-1]["kvCacheStats"]["usedNumBlocks"]
+            assert after_used_num_blocks == 0
+
+            after_free_num_blocks = results[-1]["kvCacheStats"]["freeNumBlocks"]
+            # The free block count should stay the same across iterations
+            if iter > 0:
+                assert after_free_num_blocks == prev_after_free_num_blocks
+
+            # Remember the free block count for the next iteration's comparison
+            prev_after_free_num_blocks = after_free_num_blocks
+    finally:
+        llm_ctx.shutdown()
+        llm_gen.shutdown()
+
+
+@pytest.mark.threadleak(enabled=False)
+@pytest.mark.part0
+@skip_ray
+@pytest.mark.asyncio
+async def test_llm_disagg_streaming_gen_cancelled():
+    tp_size = 1
+    use_overlap = False
+    enable_iter_req_stats = False
+
+    llm_args_extra = {}
+
+    llm_args_extra.update(
+        dict(enable_iter_perf_stats=True,
+             enable_iter_req_stats=enable_iter_req_stats,
+             disable_overlap_scheduler=not use_overlap))
+
+    llm_ctx = LLM(model=llama_model_path,
+                  kv_cache_config=global_kvcache_config_no_reuse,
+                  tensor_parallel_size=tp_size,
+                  cache_transceiver_config=CacheTransceiverConfig(
+                      backend="UCX", kv_transfer_timeout_ms=1000),
+                  **llm_args_extra)
+
+    llm_gen = LLM(model=llama_model_path,
+                  kv_cache_config=global_kvcache_config_no_reuse,
+                  tensor_parallel_size=tp_size,
+                  cache_transceiver_config=CacheTransceiverConfig(
+                      backend="UCX", kv_transfer_timeout_ms=1000),
+                  **llm_args_extra)
+
+    try:
+        num_iterations = 10
+        num_concurrent_requests = 20
+        prev_after_free_num_blocks = 0
+        for iter in range(num_iterations):
+
+            max_tokens = 1
+            sampling_params = SamplingParams(max_tokens=max_tokens)
+            disaggregated_params = DisaggregatedParams(
+                request_type="context_only")
+
+            prompts = [
+                "lorem ipsum dolor sit amet consectetur adipiscing elit sed do eiusmod tempor incididunt ut labore et dolore magna aliqua "
+                * random.randint(5, 15) for _ in range(num_concurrent_requests)
+            ]
+            # Send context-only requests
+            ctx_outputs = []
+            for prompt in prompts:
+                for output in llm_ctx.generate(
+                        [prompt],
+                        sampling_params=sampling_params,
+                        disaggregated_params=disaggregated_params):
+                    ctx_outputs.append(output)
+
+            assert len(ctx_outputs) == num_concurrent_requests
+
+            max_tokens = 300
+            sampling_params = SamplingParams(max_tokens=max_tokens,
+                                             ignore_eos=True)
+
+            # Send multiple gen-only requests concurrently
+            async def process_request(idx):
+                disagg_params = ctx_outputs[idx].disaggregated_params
+                disagg_params.request_type = "generation_only"
+
+                tokens_out = 0
+                # Ensure we always cancel early to avoid race conditions
+                stop_after_tokens = random.randint(10, 50)
+                finished_reason = None
+                async for gen_output in llm_gen.generate_async(
+                        prompts[idx],
+                        sampling_params=sampling_params,
+                        disaggregated_params=disagg_params,
+                        streaming=True):
+                    tokens_out += 1
+                    finished_reason = gen_output.outputs[0].finish_reason
+                    if tokens_out == stop_after_tokens:
+                        gen_output.abort()
+
+                return finished_reason, stop_after_tokens
+
+            # Launch all requests concurrently
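+            # Each gathered entry is the (finished_reason, stop_after_tokens)
+            # tuple returned by process_request, in submission order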
+            import asyncio
+            results = await asyncio.gather(
+                *[process_request(i) for i in range(num_concurrent_requests)])
+
+            # Verify all requests were cancelled
+            for finished_reason, stop_after_tokens in results:
+                assert finished_reason is not None
+                assert finished_reason == "cancelled"
+
+            # Check that the number of free/used blocks is as expected
+            time.sleep(1.)
+            max_retries = 10
+            for _ in range(max_retries):
+                stats_results = llm_gen.get_stats(2)
+                print("len(stats_results):", len(stats_results))
+                if len(stats_results) > 0:
+                    break
+                time.sleep(1)
+            else:
+                pytest.fail(f"Failed to get stats after {max_retries} retries")
+
+            after_used_num_blocks = stats_results[-1]["kvCacheStats"][
+                "usedNumBlocks"]
+            assert after_used_num_blocks == 0
+
+            after_free_num_blocks = stats_results[-1]["kvCacheStats"][
+                "freeNumBlocks"]
+            # The free block count should stay the same across iterations
+            if iter > 0:
+                assert after_free_num_blocks == prev_after_free_num_blocks
+
+            # Remember the free block count for the next iteration's comparison
+            prev_after_free_num_blocks = after_free_num_blocks
+    finally:
+        llm_ctx.shutdown()
+        llm_gen.shutdown()