Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-14 06:27:45 +08:00
[https://nvbugs/5550671][fix] fix disagg-serving multinodes test failure (#8307)
Signed-off-by: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com>
Signed-off-by: Mike Iovine <6158008+mikeiovine@users.noreply.github.com>
This commit: 982d4b65e8 (parent: 18a534d2b4)
@@ -1763,8 +1763,10 @@ def test_trtllm_multimodal_benchmark_serving(llm_root, llm_venv):
 @pytest.mark.skip_less_device(4)
 @pytest.mark.skip_less_device_memory(40000)
-@pytest.mark.parametrize("gen_config", ["gen_tp2pp1", "gen_tp1pp2"])
-@pytest.mark.parametrize("ctx_config", ["ctx_tp2pp1", "ctx_tp1pp2"])
+@pytest.mark.parametrize("gen_config",
+                         ["gen_tp2pp1", "gen_tp1pp2", "gen_tp1pp1"])
+@pytest.mark.parametrize("ctx_config",
+                         ["ctx_tp2pp1", "ctx_tp1pp2", "ctx_tp1pp1"])
 def test_openai_disagg_multi_nodes_completion(llm_root, llm_venv, ctx_config,
                                               gen_config):
     test_root = unittest_path() / "llmapi" / "apps"
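Note: the two stacked parametrize marks combine as a cross-product, so adding the tp1pp1 entries grows the matrix from 2 x 2 = 4 to 3 x 3 = 9 test cases. A minimal sketch of the expansion (the config names are assumed to encode tensor/pipeline parallel sizes):

    # Sketch only: reproduces how pytest expands stacked parametrize marks.
    from itertools import product

    ctx_configs = ["ctx_tp2pp1", "ctx_tp1pp2", "ctx_tp1pp1"]
    gen_configs = ["gen_tp2pp1", "gen_tp1pp2", "gen_tp1pp1"]

    # pytest generates one test id per (ctx, gen) pair, e.g.
    # test_openai_disagg_multi_nodes_completion[ctx_tp1pp1-gen_tp1pp1]
    print(len(list(product(ctx_configs, gen_configs))))  # 9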
@@ -28,9 +28,16 @@ DISAGG_SERVER_PORT = 8000
 # Exclude the current node from the node list, then return other nodes by idx
 def get_the_other_host(idx=0):
     assert len(NODE_LIST) >= 2
-    assert socket.gethostname() in NODE_LIST
     node_list = NODE_LIST.copy()
-    node_list.remove(socket.gethostname())
+    curr_host = socket.gethostname()
+    if curr_host in NODE_LIST:
+        # gethostname returns the exact node name in node list
+        node_list.remove(curr_host)
+    else:
+        # gethostname returns the full domain
+        curr_host = curr_host.split('.')[0]
+        assert curr_host in node_list
+        node_list.remove(curr_host)
     return node_list[idx]
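This is the core fix: the old code asserted that socket.gethostname() appears verbatim in NODE_LIST, which fails on clusters where gethostname() returns a fully qualified domain name while the node list holds short names. A standalone sketch of the fixed lookup (pure function, names hypothetical):

    # Illustration only, not the test's actual helper.
    def get_other_host(curr_host, node_list, idx=0):
        others = node_list.copy()
        if curr_host not in others:
            # gethostname() returned an FQDN; strip the domain suffix
            curr_host = curr_host.split('.')[0]
        assert curr_host in others
        others.remove(curr_host)
        return others[idx]

    assert get_other_host("node-1.cluster.local", ["node-1", "node-2"]) == "node-2"
    assert get_other_host("node-2", ["node-1", "node-2"]) == "node-1"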
@@ -89,8 +96,6 @@ def gen_tp_pp_size(request):
 @pytest.fixture(scope="module")
 def worker(model_name: str, ctx_tp_pp_size: tuple, gen_tp_pp_size: tuple):
-    host = socket.gethostname()
-    assert host in NODE_LIST
     extra_config = {
         "cache_transceiver_config": {
             "backend": "UCX"
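The removed assertion is the same FQDN-sensitive membership check, made redundant (and potentially spurious) by the fix above. For context, the hunk truncates extra_config at its boundary; a dict of this shape selects UCX as the KV-cache transfer backend between context and generation workers, and such options are typically serialized to a YAML file for the server (the schema beyond the backend key, and the exact hand-off, are assumptions here):

    # Illustrative only; mirrors the truncated extra_config in the hunk.
    import yaml

    extra_config = {
        "cache_transceiver_config": {
            "backend": "UCX",  # KV-cache transfer between ctx and gen workers
        },
    }
    with open("config.yml", "w") as f:
        yaml.safe_dump(extra_config, f)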
@@ -111,7 +116,7 @@ def worker(model_name: str, ctx_tp_pp_size: tuple, gen_tp_pp_size: tuple):
             cli_args=args,
             host="0.0.0.0",
             env=env(),
-            llmapi_launch=True,
+            llmapi_launch=False,
             rank=RANK % SLURM_NTASKS_PER_NODE,
             extra_config=extra_config) as server:
         yield server
@@ -125,6 +130,7 @@ def worker(model_name: str, ctx_tp_pp_size: tuple, gen_tp_pp_size: tuple):
             cli_args=args,
             host="0.0.0.0",
             env=env(),
+            llmapi_launch=False,
             rank=RANK % SLURM_NTASKS_PER_NODE,
             extra_config=extra_config) as server:
         yield server
@@ -169,13 +175,11 @@ def disagg_server(worker: RemoteOpenAIServer):
         with RemoteDisaggOpenAIServer(ctx_servers=[ctx_url],
                                       gen_servers=[gen_url],
                                       port=DISAGG_SERVER_PORT,
-                                      llmapi_launch=True,
+                                      llmapi_launch=False,
                                       env=env()) as server:
             yield server
     else:
         print(f"skipping disagg_server for rank {RANK} node rank {NODE_RANK}")
-        url = f"http://{get_the_other_host(0)}:{DISAGG_SERVER_PORT}/health/"
-        wait_for_endpoint_ready(url, 60)
         yield None
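Across the worker and disagg_server fixtures the commit flips llmapi_launch from True to False (and passes it explicitly in the second worker branch), so the test servers are no longer wrapped by the trtllm-llmapi-launch helper; the readiness wait also moves out of the fixture and into the test body (next hunk). A hypothetical sketch of what such a flag typically controls, not the actual RemoteOpenAIServer implementation:

    # Hypothetical sketch; the real launcher lives in the test utilities.
    def build_server_cmd(base_cmd, llmapi_launch):
        # When enabled, the server command is wrapped by the
        # trtllm-llmapi-launch helper; this commit disables the wrapping.
        return (["trtllm-llmapi-launch"] + base_cmd) if llmapi_launch else base_cmd

    assert build_server_cmd(["trtllm-serve", "model"], False) == ["trtllm-serve", "model"]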
@@ -190,6 +194,9 @@ def client(disagg_server: RemoteDisaggOpenAIServer):
 def test_completion(client: openai.OpenAI,
                     disagg_server: RemoteDisaggOpenAIServer, model_name: str):
+    if len(NODE_LIST) != 2:
+        pytest.skip("This test is only expected to run with 2 nodes")
+        return
     if is_pytest_node():
         print(f"running test_completion on rank {RANK} node rank {NODE_RANK}")
         prompt = "What is the result of 1+1? Answer in one word: "
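One detail worth knowing about the new guard: pytest.skip() raises an exception, so the return after it is defensive but unreachable. A minimal sketch of the guard's behavior:

    # Sketch: pytest.skip() raises pytest's Skipped exception, aborting the test.
    import pytest

    def test_requires_two_nodes():
        node_list = ["node-a"]  # hypothetical single-node run
        if len(node_list) != 2:
            pytest.skip("This test is only expected to run with 2 nodes")
            # unreachable: control never returns from pytest.skip()
        assert len(node_list) == 2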
@@ -208,7 +215,8 @@ def test_completion(client: openai.OpenAI,
     elif is_gen_node():
         # keep gen workers alive until the test ends, again we hope the NODE_LIST is ordered by NODE_RANK
         url = f"http://{get_the_other_host(0)}:{DISAGG_SERVER_PORT}/health/"
-        wait_for_endpoint_down(url, 60)
+        wait_for_endpoint_ready(url)
+        wait_for_endpoint_down(url)
         assert True
     else:
         assert True
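The gen node now first waits for the disagg server's /health/ endpoint to come up, then waits for it to go down using the helpers' default timeout rather than the old hard 60 seconds, which keeps the generation workers alive for the whole test. A minimal sketch of what such polling helpers plausibly look like (the real ones live in the test utilities; signatures and defaults assumed):

    # Minimal polling sketch, assuming the requests package is available.
    import time
    import requests

    def wait_for_endpoint_ready(url, timeout=300.0):
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            try:
                if requests.get(url, timeout=5).status_code == 200:
                    return
            except requests.RequestException:
                pass  # not up yet
            time.sleep(2)
        raise TimeoutError(f"{url} never became ready")

    def wait_for_endpoint_down(url, timeout=300.0):
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            try:
                requests.get(url, timeout=5)
            except requests.RequestException:
                return  # connection refused: the server is down
            time.sleep(2)
        raise TimeoutError(f"{url} is still up")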