diff --git a/tests/integration/defs/accuracy/test_disaggregated_serving.py b/tests/integration/defs/accuracy/test_disaggregated_serving.py
index f7719da646..57168aa581 100644
--- a/tests/integration/defs/accuracy/test_disaggregated_serving.py
+++ b/tests/integration/defs/accuracy/test_disaggregated_serving.py
@@ -450,6 +450,32 @@ def launch_disaggregated_llm(
     _show_kvcache_time(kv_cache_perf_dir)
     _get_perf_metrics()
 
+    # Gracefully shut down all server processes
+    all_processes = list(
+        itertools.chain(ctx_processes, gen_processes, server_processes))
+
+    # SIGTERM triggers llm.shutdown() inside each trtllm-serve, cleaning up the executor and MPI workers.
+    for process in all_processes:
+        if process.poll() is None:
+            process.terminate()
+
+    # Wait up to 5s total (one shared deadline, not 5s per process), then SIGKILL any process that doesn't exit.
+    # This is a safety net for when llm.shutdown() hangs.
+    deadline = time.monotonic() + 5
+    for process in all_processes:
+        remaining = max(0, deadline - time.monotonic())
+        try:
+            process.wait(timeout=remaining)
+        except subprocess.TimeoutExpired:
+            try:
+                process.kill()
+                # Reap the killed child so it is not left as a zombie; bounded so teardown cannot hang.
+                process.wait(timeout=5)
+            except (ProcessLookupError, subprocess.TimeoutExpired):
+                pass  # already exited between timeout and kill, or still not reapable
+        except OSError:
+            pass  # process already gone
+
 
 def run_parallel_test(model_name: str,
                       model_path: str,