[ROCm][P/D][MORI][BugFix] Ensure correct api is used when making requests to prefill / decode nodes (#39835)

Signed-off-by: Randall Smith <Randall.Smith@amd.com>
This commit is contained in:
rasmith
2026-04-21 19:48:25 -05:00
committed by GitHub
parent 46794958f0
commit cefa5281a7
2 changed files with 28 additions and 13 deletions
@@ -12,7 +12,7 @@ import aiohttp
import msgpack
import regex as re
import zmq
from quart import Quart, make_response, request
from quart import Quart, Request, make_response, request
from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common import (
MoRIIOConstants,
@@ -139,10 +139,13 @@ async def send_request_to_prefill(
return await response.json()
else:
raise RuntimeError(
"send_request_to_prefill response.status != 200response.status = ",
response.status,
error_message = (
f"send_request_to_prefill response ={response},"
f"reason={response.reason}, status={response.status},"
f"method={response.method}, url={response.url},"
f"real_url={response.real_url}"
)
raise RuntimeError(error_message)
async def start_decode_request(endpoint, req_data, request_id):
@@ -163,9 +166,13 @@ async def stream_decode_response(session, response, request_id):
async for chunk_bytes in response.content.iter_chunked(1024):
yield chunk_bytes
else:
raise RuntimeError(
f"decode response.status != 200, status = {response.status}"
error_message = (
f"stream_decode_response response ={response},"
f"reason={response.reason}, status={response.status},"
f"method={response.method}, url={response.url},"
f"real_url={response.real_url}"
)
raise RuntimeError(error_message)
finally:
await session.close()
@@ -175,8 +182,16 @@ def example_round_robin_dp_loader(request_number, dp_size):
@app.route("/v1/completions", methods=["POST"])
async def handle_completions_request():
return await handle_request("/completions", request)
@app.route("/v1/chat/completions", methods=["POST"])
async def handle_request():
async def handle_chat_completions_request():
return await handle_request("/chat/completions", request)
async def handle_request(api: str, request: Request):
try:
with _list_lock:
global request_nums
@@ -230,9 +245,10 @@ async def handle_request():
)
req_data_to_prefill["kv_transfer_params"]["transfer_id"] = transfer_id
prefill_request_url = prefill_instance_endpoint["request_address"] + api
send_prefill_task = asyncio.create_task(
send_request_to_prefill(
prefill_instance_endpoint["request_address"],
prefill_request_url,
req_data_to_prefill,
request_id,
decode_instance_endpoint,
@@ -241,7 +257,7 @@ async def handle_request():
selected_prefill_dp_rank,
)
)
ip, port = extract_ip_port_fast(prefill_instance_endpoint["request_address"])
ip, port = extract_ip_port_fast(prefill_request_url)
req_data["max_tokens"] -= 1
@@ -276,10 +292,9 @@ async def handle_request():
req_data["kv_transfer_params"]["remote_dp_rank"] = selected_prefill_dp_rank
req_data["kv_transfer_params"]["transfer_id"] = transfer_id
decode_request_url = decode_instance_endpoint["request_address"] + api
decode_request_task = asyncio.create_task(
start_decode_request(
decode_instance_endpoint["request_address"], req_data, request_id
)
start_decode_request(decode_request_url, req_data, request_id)
)
session, decode_response = await decode_request_task