[https://nvbugs/5550671][fix] fix disagg-serving multinodes test failure (#8307)

Signed-off-by: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com>
Signed-off-by: Mike Iovine <6158008+mikeiovine@users.noreply.github.com>
This commit is contained in:
Lizhi Zhou 2025-10-14 14:01:00 +08:00 committed by Yanchao Lu
parent 18a534d2b4
commit 982d4b65e8
2 changed files with 21 additions and 11 deletions

View File

@@ -1763,8 +1763,10 @@ def test_trtllm_multimodal_benchmark_serving(llm_root, llm_venv):
@pytest.mark.skip_less_device(4)
@pytest.mark.skip_less_device_memory(40000)
@pytest.mark.parametrize("gen_config", ["gen_tp2pp1", "gen_tp1pp2"])
@pytest.mark.parametrize("ctx_config", ["ctx_tp2pp1", "ctx_tp1pp2"])
@pytest.mark.parametrize("gen_config",
["gen_tp2pp1", "gen_tp1pp2", "gen_tp1pp1"])
@pytest.mark.parametrize("ctx_config",
["ctx_tp2pp1", "ctx_tp1pp2", "ctx_tp1pp1"])
def test_openai_disagg_multi_nodes_completion(llm_root, llm_venv, ctx_config,
gen_config):
test_root = unittest_path() / "llmapi" / "apps"

View File

@@ -28,9 +28,16 @@ DISAGG_SERVER_PORT = 8000
# Exclude the current node from the node list, then return other nodes by idx
def get_the_other_host(idx=0):
    """Return the idx-th node of NODE_LIST, excluding the current host.

    The pre-commit version asserted ``socket.gethostname() in NODE_LIST``
    and removed it directly, which fails when ``gethostname()`` returns a
    fully-qualified domain name while NODE_LIST holds short node names.
    Here we normalize the hostname first, then remove it once.

    Args:
        idx: Index into the remaining nodes after the current host is
            removed (default 0, i.e. the first "other" node).

    Returns:
        The hostname of the idx-th node other than the current one.
    """
    assert len(NODE_LIST) >= 2
    node_list = NODE_LIST.copy()
    curr_host = socket.gethostname()
    if curr_host not in node_list:
        # gethostname() returned the full domain name; strip the domain
        # suffix so it matches the short names stored in NODE_LIST.
        curr_host = curr_host.split('.')[0]
    assert curr_host in node_list
    node_list.remove(curr_host)
    return node_list[idx]
@@ -89,8 +96,6 @@ def gen_tp_pp_size(request):
@pytest.fixture(scope="module")
def worker(model_name: str, ctx_tp_pp_size: tuple, gen_tp_pp_size: tuple):
host = socket.gethostname()
assert host in NODE_LIST
extra_config = {
"cache_transceiver_config": {
"backend": "UCX"
@@ -111,7 +116,7 @@ def worker(model_name: str, ctx_tp_pp_size: tuple, gen_tp_pp_size: tuple):
cli_args=args,
host="0.0.0.0",
env=env(),
llmapi_launch=True,
llmapi_launch=False,
rank=RANK % SLURM_NTASKS_PER_NODE,
extra_config=extra_config) as server:
yield server
@@ -125,6 +130,7 @@ def worker(model_name: str, ctx_tp_pp_size: tuple, gen_tp_pp_size: tuple):
cli_args=args,
host="0.0.0.0",
env=env(),
llmapi_launch=False,
rank=RANK % SLURM_NTASKS_PER_NODE,
extra_config=extra_config) as server:
yield server
@@ -169,13 +175,11 @@ def disagg_server(worker: RemoteOpenAIServer):
with RemoteDisaggOpenAIServer(ctx_servers=[ctx_url],
gen_servers=[gen_url],
port=DISAGG_SERVER_PORT,
llmapi_launch=True,
llmapi_launch=False,
env=env()) as server:
yield server
else:
print(f"skipping disagg_server for rank {RANK} node rank {NODE_RANK}")
url = f"http://{get_the_other_host(0)}:{DISAGG_SERVER_PORT}/health/"
wait_for_endpoint_ready(url, 60)
yield None
@@ -190,6 +194,9 @@ def client(disagg_server: RemoteDisaggOpenAIServer):
def test_completion(client: openai.OpenAI,
disagg_server: RemoteDisaggOpenAIServer, model_name: str):
if len(NODE_LIST) != 2:
pytest.skip("This test is only expected to run with 2 nodes")
return
if is_pytest_node():
print(f"running test_completion on rank {RANK} node rank {NODE_RANK}")
prompt = "What is the result of 1+1? Answer in one word: "
@@ -208,7 +215,8 @@ def test_completion(client: openai.OpenAI,
elif is_gen_node():
# keep gen workers alive until the test ends, again we hope the NODE_LIST is ordered by NODE_RANK
url = f"http://{get_the_other_host(0)}:{DISAGG_SERVER_PORT}/health/"
wait_for_endpoint_down(url, 60)
wait_for_endpoint_ready(url)
wait_for_endpoint_down(url)
assert True
else:
assert True