[None][chore] Unit test for disagg gen cancellation (#11108)

Signed-off-by: Patrice Castonguay <55748270+pcastonguay@users.noreply.github.com>
Patrice Castonguay 2026-02-09 14:39:02 -05:00 committed by GitHub
parent ea81a03dd1
commit c68d916b6f

@@ -23,8 +23,8 @@ from .lora_test_utils import (
     create_mock_nemo_lora_checkpoint, compare_cuda_graph_lora_params_filler,
     CUDAGraphLoRATestParams, test_lora_with_and_without_cuda_graph)
 from .test_llm import (_test_llm_capture_request_error, get_model_path,
-                       global_kvcache_config, llama_model_path,
-                       llm_get_stats_async_test_harness,
+                       global_kvcache_config, global_kvcache_config_no_reuse,
+                       llama_model_path, llm_get_stats_async_test_harness,
                        llm_get_stats_test_harness,
                        llm_return_logprobs_test_harness, llm_test_harness,
                        prompts, run_llm_abort_request,
@@ -1313,3 +1313,228 @@ def test_llm_context_only_timed_out_kv_cache_exhausted(sender_future_timeout_ms,
final_used_num_blocks = results[0]["kvCacheStats"]["usedNumBlocks"]
assert final_used_num_blocks == 0


@pytest.mark.threadleak(enabled=False)
@pytest.mark.part0
@skip_ray
@pytest.mark.asyncio
async def test_llm_disagg_gen_cancelled():
tp_size = 1
use_overlap = False
enable_iter_req_stats = False
llm_args_extra = {}
llm_args_extra.update(
dict(enable_iter_perf_stats=True,
enable_iter_req_stats=enable_iter_req_stats,
disable_overlap_scheduler=not use_overlap))
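    # Two LLM instances emulate a disaggregated deployment: llm_ctx runs the
    # context (prefill) phase and llm_gen the generation (decode) phase, with
    # KV cache handed off through the UCX cache transceiver backend.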
llm_ctx = LLM(model=llama_model_path,
kv_cache_config=global_kvcache_config_no_reuse,
tensor_parallel_size=tp_size,
cache_transceiver_config=CacheTransceiverConfig(
backend="UCX", kv_transfer_timeout_ms=1000),
**llm_args_extra)
llm_gen = LLM(model=llama_model_path,
kv_cache_config=global_kvcache_config_no_reuse,
tensor_parallel_size=tp_size,
cache_transceiver_config=CacheTransceiverConfig(
backend="UCX", kv_transfer_timeout_ms=1000),
**llm_args_extra)
try:
num_iterations = 10
prev_after_free_num_blocks = 0
for iter in range(num_iterations):
max_tokens = 1
sampling_params = SamplingParams(max_tokens=max_tokens)
disaggregated_params = DisaggregatedParams(
request_type="context_only")
prompt = [
"lorem ipsum dolor sit amet consectetur adipiscing elit sed do eiusmod tempor incididunt ut labore et dolore magna aliqua "
* 10
]
# Send context-only request
ctx_outputs = []
for output in llm_ctx.generate(
prompt,
sampling_params=sampling_params,
disaggregated_params=disaggregated_params):
ctx_outputs.append(output)
assert len(ctx_outputs) == 1
max_tokens = 10000
sampling_params = SamplingParams(max_tokens=max_tokens,
ignore_eos=True)
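            # Reuse the disaggregated params returned by the context-only
            # request so the generation instance can fetch the KV cache
            # produced during the context phase.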
disaggregated_params = ctx_outputs[0].disaggregated_params
disaggregated_params.request_type = "generation_only"
# Send gen-only request
gen_output = llm_gen.generate_async(
prompt[0],
sampling_params=sampling_params,
disaggregated_params=disaggregated_params)
            # Sleep briefly (between 0.2 and 0.7 seconds) so some tokens are generated before cancelling
sleep_time = random.uniform(0.2, 0.7)
time.sleep(sleep_time)
            # Abort the generation request
gen_output.abort()
result = await gen_output.aresult()
num_output_tokens = len(result.outputs[0].token_ids)
print(f"num output tokens: {num_output_tokens}")
assert result.outputs[0].finish_reason == "cancelled"
            # Now check that the number of free/used blocks is as expected
time.sleep(1.)
max_retries = 10
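            # Stats are reported asynchronously, so poll until the expected
            # number of per-iteration entries is available.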
for _ in range(max_retries):
results = llm_gen.get_stats(2)
print("len(results):", len(results))
                expected_num_stats = num_output_tokens - (1 if iter == 0 else 0)
                if len(results) == expected_num_stats:
                    break
                time.sleep(1)
            else:
                pytest.fail(
                    f"Failed to get stats with len=={expected_num_stats} after {max_retries} retries"
                )
after_used_num_blocks = results[-1]["kvCacheStats"]["usedNumBlocks"]
assert after_used_num_blocks == 0
after_free_num_blocks = results[-1]["kvCacheStats"]["freeNumBlocks"]
# Check that number of free blocks stays the same
if iter > 0:
assert after_free_num_blocks == prev_after_free_num_blocks
            # Remember the value so the next iteration can compare against it
            prev_after_free_num_blocks = after_free_num_blocks
finally:
llm_ctx.shutdown()
llm_gen.shutdown()


@pytest.mark.threadleak(enabled=False)
@pytest.mark.part0
@skip_ray
@pytest.mark.asyncio
async def test_llm_disagg_streaming_gen_cancelled():
tp_size = 1
use_overlap = False
enable_iter_req_stats = False
llm_args_extra = {}
llm_args_extra.update(
dict(enable_iter_perf_stats=True,
enable_iter_req_stats=enable_iter_req_stats,
disable_overlap_scheduler=not use_overlap))
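    # Same disaggregated context/generation instance pair as in the
    # non-streaming test above.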
llm_ctx = LLM(model=llama_model_path,
kv_cache_config=global_kvcache_config_no_reuse,
tensor_parallel_size=tp_size,
cache_transceiver_config=CacheTransceiverConfig(
backend="UCX", kv_transfer_timeout_ms=1000),
**llm_args_extra)
llm_gen = LLM(model=llama_model_path,
kv_cache_config=global_kvcache_config_no_reuse,
tensor_parallel_size=tp_size,
cache_transceiver_config=CacheTransceiverConfig(
backend="UCX", kv_transfer_timeout_ms=1000),
**llm_args_extra)
try:
num_iterations = 10
num_concurrent_requests = 20
prev_after_free_num_blocks = 0
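        # Each iteration sends a batch of streaming generation-only requests,
        # cancels each one mid-stream at a random point, and then verifies
        # that all KV cache blocks are released.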
for iter in range(num_iterations):
max_tokens = 1
sampling_params = SamplingParams(max_tokens=max_tokens)
disaggregated_params = DisaggregatedParams(
request_type="context_only")
prompts = [
"lorem ipsum dolor sit amet consectetur adipiscing elit sed do eiusmod tempor incididunt ut labore et dolore magna aliqua "
* random.randint(5, 15) for _ in range(num_concurrent_requests)
]
# Send context-only requests
ctx_outputs = []
for prompt in prompts:
for output in llm_ctx.generate(
[prompt],
sampling_params=sampling_params,
disaggregated_params=disaggregated_params):
ctx_outputs.append(output)
assert len(ctx_outputs) == num_concurrent_requests
max_tokens = 300
sampling_params = SamplingParams(max_tokens=max_tokens,
ignore_eos=True)
# Send multiple gen-only requests concurrently
async def process_request(idx):
disagg_params = ctx_outputs[idx].disaggregated_params
disagg_params.request_type = "generation_only"
tokens_out = 0
# Ensure we always cancel early to avoid race conditions
stop_after_tokens = random.randint(10, 50)
finished_reason = None
async for gen_output in llm_gen.generate_async(
prompts[idx],
sampling_params=sampling_params,
disaggregated_params=disagg_params,
streaming=True):
tokens_out += 1
finished_reason = gen_output.outputs[0].finish_reason
if tokens_out == stop_after_tokens:
gen_output.abort()
return finished_reason, stop_after_tokens
# Launch all requests concurrently
import asyncio
results = await asyncio.gather(
*[process_request(i) for i in range(num_concurrent_requests)])
# Verify all requests were cancelled
for finished_reason, stop_after_tokens in results:
assert finished_reason is not None
assert finished_reason == "cancelled"
# Check that the number of free/used blocks is as expected
time.sleep(1.)
max_retries = 10
for _ in range(max_retries):
stats_results = llm_gen.get_stats(2)
print("len(stats_results):", len(stats_results))
if len(stats_results) > 0:
break
time.sleep(1)
else:
pytest.fail(f"Failed to get stats after {max_retries} retries")
after_used_num_blocks = stats_results[-1]["kvCacheStats"][
"usedNumBlocks"]
assert after_used_num_blocks == 0
after_free_num_blocks = stats_results[-1]["kvCacheStats"][
"freeNumBlocks"]
# Check that number of free blocks stays the same
if iter > 0:
assert after_free_num_blocks == prev_after_free_num_blocks
            # Remember the value so the next iteration can compare against it
            prev_after_free_num_blocks = after_free_num_blocks
finally:
llm_ctx.shutdown()
llm_gen.shutdown()