Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-02-16 15:55:08 +08:00
[None][chore] Unit test for disagg gen cancellation (#11108)
Signed-off-by: Patrice Castonguay <55748270+pcastonguay@users.noreply.github.com>
commit c68d916b6f (parent ea81a03dd1)
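The diff below adds two asyncio tests exercising cancellation in the disaggregated serving path: a context-only request on one LLM instance produces disaggregated_params, a second instance replays them as a generation-only request, the request is aborted mid-generation, and the tests then verify that the finish reason is "cancelled" and that KV-cache blocks are released. As orientation, here is a condensed sketch of that handoff (not part of the commit; it assumes llm_ctx, llm_gen, SamplingParams, and DisaggregatedParams are set up and imported exactly as in the tests that follow):

async def _disagg_generate_then_cancel(llm_ctx, llm_gen, prompt):
    # Sketch only: condensed from test_llm_disagg_gen_cancelled below.
    # Phase 1: context-only prefill on the context engine.
    ctx_params = DisaggregatedParams(request_type="context_only")
    ctx_output = list(
        llm_ctx.generate([prompt],
                         sampling_params=SamplingParams(max_tokens=1),
                         disaggregated_params=ctx_params))[0]

    # Phase 2: replay the returned params as a generation-only request
    # on the generation engine.
    gen_params = ctx_output.disaggregated_params
    gen_params.request_type = "generation_only"
    gen_future = llm_gen.generate_async(
        prompt,
        sampling_params=SamplingParams(max_tokens=10000, ignore_eos=True),
        disaggregated_params=gen_params)

    # Abort mid-generation; the request should finish with reason "cancelled".
    gen_future.abort()
    result = await gen_future.aresult()
    assert result.outputs[0].finish_reason == "cancelled"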
@@ -23,8 +23,8 @@ from .lora_test_utils import (
     create_mock_nemo_lora_checkpoint, compare_cuda_graph_lora_params_filler,
     CUDAGraphLoRATestParams, test_lora_with_and_without_cuda_graph)
 from .test_llm import (_test_llm_capture_request_error, get_model_path,
-                       global_kvcache_config, llama_model_path,
-                       llm_get_stats_async_test_harness,
+                       global_kvcache_config, global_kvcache_config_no_reuse,
+                       llama_model_path, llm_get_stats_async_test_harness,
                        llm_get_stats_test_harness,
                        llm_return_logprobs_test_harness, llm_test_harness,
                        prompts, run_llm_abort_request,
@@ -1313,3 +1313,228 @@ def test_llm_context_only_timed_out_kv_cache_exhausted(sender_future_timeout_ms,
    final_used_num_blocks = results[0]["kvCacheStats"]["usedNumBlocks"]

    assert final_used_num_blocks == 0


@pytest.mark.threadleak(enabled=False)
@pytest.mark.part0
@skip_ray
@pytest.mark.asyncio
async def test_llm_disagg_gen_cancelled():
    tp_size = 1
    use_overlap = False
    enable_iter_req_stats = False

    llm_args_extra = {}

    llm_args_extra.update(
        dict(enable_iter_perf_stats=True,
             enable_iter_req_stats=enable_iter_req_stats,
             disable_overlap_scheduler=not use_overlap))

    llm_ctx = LLM(model=llama_model_path,
                  kv_cache_config=global_kvcache_config_no_reuse,
                  tensor_parallel_size=tp_size,
                  cache_transceiver_config=CacheTransceiverConfig(
                      backend="UCX", kv_transfer_timeout_ms=1000),
                  **llm_args_extra)

    llm_gen = LLM(model=llama_model_path,
                  kv_cache_config=global_kvcache_config_no_reuse,
                  tensor_parallel_size=tp_size,
                  cache_transceiver_config=CacheTransceiverConfig(
                      backend="UCX", kv_transfer_timeout_ms=1000),
                  **llm_args_extra)

    try:
        num_iterations = 10
        prev_after_free_num_blocks = 0
        for iter in range(num_iterations):

            max_tokens = 1
            sampling_params = SamplingParams(max_tokens=max_tokens)
            disaggregated_params = DisaggregatedParams(
                request_type="context_only")

            prompt = [
                "lorem ipsum dolor sit amet consectetur adipiscing elit sed do eiusmod tempor incididunt ut labore et dolore magna aliqua "
                * 10
            ]
            # Send context-only request
            ctx_outputs = []
            for output in llm_ctx.generate(
                    prompt,
                    sampling_params=sampling_params,
                    disaggregated_params=disaggregated_params):
                ctx_outputs.append(output)

            assert len(ctx_outputs) == 1

            max_tokens = 10000
            sampling_params = SamplingParams(max_tokens=max_tokens,
                                             ignore_eos=True)
            disaggregated_params = ctx_outputs[0].disaggregated_params
            disaggregated_params.request_type = "generation_only"

            # Send gen-only request
            gen_output = llm_gen.generate_async(
                prompt[0],
                sampling_params=sampling_params,
                disaggregated_params=disaggregated_params)

            # Sleep between 0.2 and 0.7 seconds so that some tokens get generated
            sleep_time = random.uniform(0.2, 0.7)
            time.sleep(sleep_time)
            # Abort the generation request
            gen_output.abort()
            result = await gen_output.aresult()
            num_output_tokens = len(result.outputs[0].token_ids)
            print(f"num output tokens: {num_output_tokens}")
            assert result.outputs[0].finish_reason == "cancelled"

            # Now check that the number of free/used blocks is as expected
            time.sleep(1.)
            max_retries = 10
            for _ in range(max_retries):
                results = llm_gen.get_stats(2)
                print("len(results):", len(results))
                if len(results) == num_output_tokens - (1 if iter == 0 else 0):
                    break
                time.sleep(1)
            else:
                pytest.fail(
                    f"Failed to get stats with the expected number of entries after {max_retries} retries"
                )

            after_used_num_blocks = results[-1]["kvCacheStats"]["usedNumBlocks"]
            assert after_used_num_blocks == 0

            after_free_num_blocks = results[-1]["kvCacheStats"]["freeNumBlocks"]
            # Check that the number of free blocks stays the same across iterations
            if iter > 0:
                assert after_free_num_blocks == prev_after_free_num_blocks

            # Remember the free-block count for the next iteration
            prev_after_free_num_blocks = after_free_num_blocks
    finally:
        llm_ctx.shutdown()
        llm_gen.shutdown()


@pytest.mark.threadleak(enabled=False)
@pytest.mark.part0
@skip_ray
@pytest.mark.asyncio
async def test_llm_disagg_streaming_gen_cancelled():
    tp_size = 1
    use_overlap = False
    enable_iter_req_stats = False

    llm_args_extra = {}

    llm_args_extra.update(
        dict(enable_iter_perf_stats=True,
             enable_iter_req_stats=enable_iter_req_stats,
             disable_overlap_scheduler=not use_overlap))

    llm_ctx = LLM(model=llama_model_path,
                  kv_cache_config=global_kvcache_config_no_reuse,
                  tensor_parallel_size=tp_size,
                  cache_transceiver_config=CacheTransceiverConfig(
                      backend="UCX", kv_transfer_timeout_ms=1000),
                  **llm_args_extra)

    llm_gen = LLM(model=llama_model_path,
                  kv_cache_config=global_kvcache_config_no_reuse,
                  tensor_parallel_size=tp_size,
                  cache_transceiver_config=CacheTransceiverConfig(
                      backend="UCX", kv_transfer_timeout_ms=1000),
                  **llm_args_extra)

    try:
        num_iterations = 10
        num_concurrent_requests = 20
        prev_after_free_num_blocks = 0
        for iter in range(num_iterations):

            max_tokens = 1
            sampling_params = SamplingParams(max_tokens=max_tokens)
            disaggregated_params = DisaggregatedParams(
                request_type="context_only")

            prompts = [
                "lorem ipsum dolor sit amet consectetur adipiscing elit sed do eiusmod tempor incididunt ut labore et dolore magna aliqua "
                * random.randint(5, 15) for _ in range(num_concurrent_requests)
            ]
            # Send context-only requests
            ctx_outputs = []
            for prompt in prompts:
                for output in llm_ctx.generate(
                        [prompt],
                        sampling_params=sampling_params,
                        disaggregated_params=disaggregated_params):
                    ctx_outputs.append(output)

            assert len(ctx_outputs) == num_concurrent_requests

            max_tokens = 300
            sampling_params = SamplingParams(max_tokens=max_tokens,
                                             ignore_eos=True)

            # Send multiple gen-only requests concurrently
            async def process_request(idx):
                disagg_params = ctx_outputs[idx].disaggregated_params
                disagg_params.request_type = "generation_only"

                tokens_out = 0
                # Ensure we always cancel early to avoid race conditions
                stop_after_tokens = random.randint(10, 50)
                finished_reason = None
                async for gen_output in llm_gen.generate_async(
                        prompts[idx],
                        sampling_params=sampling_params,
                        disaggregated_params=disagg_params,
                        streaming=True):
                    tokens_out += 1
                    finished_reason = gen_output.outputs[0].finish_reason
                    if tokens_out == stop_after_tokens:
                        gen_output.abort()

                return finished_reason, stop_after_tokens

            # Launch all requests concurrently
            import asyncio
            results = await asyncio.gather(
                *[process_request(i) for i in range(num_concurrent_requests)])

            # Verify all requests were cancelled
            for finished_reason, stop_after_tokens in results:
                assert finished_reason is not None
                assert finished_reason == "cancelled"

            # Check that the number of free/used blocks is as expected
            time.sleep(1.)
            max_retries = 10
            for _ in range(max_retries):
                stats_results = llm_gen.get_stats(2)
                print("len(stats_results):", len(stats_results))
                if len(stats_results) > 0:
                    break
                time.sleep(1)
            else:
                pytest.fail(f"Failed to get stats after {max_retries} retries")

            after_used_num_blocks = stats_results[-1]["kvCacheStats"][
                "usedNumBlocks"]
            assert after_used_num_blocks == 0

            after_free_num_blocks = stats_results[-1]["kvCacheStats"][
                "freeNumBlocks"]
            # Check that the number of free blocks stays the same across iterations
            if iter > 0:
                assert after_free_num_blocks == prev_after_free_num_blocks

            # Remember the free-block count for the next iteration
            prev_after_free_num_blocks = after_free_num_blocks
    finally:
        llm_ctx.shutdown()
        llm_gen.shutdown()
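Both tests end each iteration with a similar polling loop: call llm_gen.get_stats(2) until the stats settle, assert that usedNumBlocks has dropped to 0, and check that freeNumBlocks stays stable across iterations. If more cancellation tests follow, a small helper along these lines could absorb the common part (a sketch only, built from the calls the tests already make; not part of this commit):

import time

import pytest


def wait_for_kv_cache_drain(llm, max_retries=10, poll_interval_s=1.0):
    # Sketch only: assumes the same get_stats()/kvCacheStats fields read by
    # the tests in this commit.
    for _ in range(max_retries):
        results = llm.get_stats(2)
        if results and results[-1]["kvCacheStats"]["usedNumBlocks"] == 0:
            return results[-1]
        time.sleep(poll_interval_s)
    pytest.fail(f"KV cache blocks not released after {max_retries} retries")

Each test would then keep only its own loop-length condition (a token-count match in the first test, a non-empty result in the second) and the freeNumBlocks comparison.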