mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[ROCm][P/D][MORI][BugFix] Ensure correct api is used when making requests to prefill / decode nodes (#39835)
Signed-off-by: Randall Smith <Randall.Smith@amd.com>
This commit is contained in:
@@ -12,7 +12,7 @@ import aiohttp
|
||||
import msgpack
|
||||
import regex as re
|
||||
import zmq
|
||||
from quart import Quart, make_response, request
|
||||
from quart import Quart, Request, make_response, request
|
||||
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common import (
|
||||
MoRIIOConstants,
|
||||
@@ -139,10 +139,13 @@ async def send_request_to_prefill(
|
||||
return await response.json()
|
||||
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"send_request_to_prefill response.status != 200response.status = ",
|
||||
response.status,
|
||||
error_message = (
|
||||
f"send_request_to_prefill response ={response},"
|
||||
f"reason={response.reason}, status={response.status},"
|
||||
f"method={response.method}, url={response.url},"
|
||||
f"real_url={response.real_url}"
|
||||
)
|
||||
raise RuntimeError(error_message)
|
||||
|
||||
|
||||
async def start_decode_request(endpoint, req_data, request_id):
|
||||
@@ -163,9 +166,13 @@ async def stream_decode_response(session, response, request_id):
|
||||
async for chunk_bytes in response.content.iter_chunked(1024):
|
||||
yield chunk_bytes
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"decode response.status != 200, status = {response.status}"
|
||||
error_message = (
|
||||
f"stream_decode_response response ={response},"
|
||||
f"reason={response.reason}, status={response.status},"
|
||||
f"method={response.method}, url={response.url},"
|
||||
f"real_url={response.real_url}"
|
||||
)
|
||||
raise RuntimeError(error_message)
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
@@ -175,8 +182,16 @@ def example_round_robin_dp_loader(request_number, dp_size):
|
||||
|
||||
|
||||
@app.route("/v1/completions", methods=["POST"])
|
||||
async def handle_completions_request():
|
||||
return await handle_request("/completions", request)
|
||||
|
||||
|
||||
@app.route("/v1/chat/completions", methods=["POST"])
|
||||
async def handle_request():
|
||||
async def handle_chat_completions_request():
|
||||
return await handle_request("/chat/completions", request)
|
||||
|
||||
|
||||
async def handle_request(api: str, request: Request):
|
||||
try:
|
||||
with _list_lock:
|
||||
global request_nums
|
||||
@@ -230,9 +245,10 @@ async def handle_request():
|
||||
)
|
||||
req_data_to_prefill["kv_transfer_params"]["transfer_id"] = transfer_id
|
||||
|
||||
prefill_request_url = prefill_instance_endpoint["request_address"] + api
|
||||
send_prefill_task = asyncio.create_task(
|
||||
send_request_to_prefill(
|
||||
prefill_instance_endpoint["request_address"],
|
||||
prefill_request_url,
|
||||
req_data_to_prefill,
|
||||
request_id,
|
||||
decode_instance_endpoint,
|
||||
@@ -241,7 +257,7 @@ async def handle_request():
|
||||
selected_prefill_dp_rank,
|
||||
)
|
||||
)
|
||||
ip, port = extract_ip_port_fast(prefill_instance_endpoint["request_address"])
|
||||
ip, port = extract_ip_port_fast(prefill_request_url)
|
||||
|
||||
req_data["max_tokens"] -= 1
|
||||
|
||||
@@ -276,10 +292,9 @@ async def handle_request():
|
||||
req_data["kv_transfer_params"]["remote_dp_rank"] = selected_prefill_dp_rank
|
||||
req_data["kv_transfer_params"]["transfer_id"] = transfer_id
|
||||
|
||||
decode_request_url = decode_instance_endpoint["request_address"] + api
|
||||
decode_request_task = asyncio.create_task(
|
||||
start_decode_request(
|
||||
decode_instance_endpoint["request_address"], req_data, request_id
|
||||
)
|
||||
start_decode_request(decode_request_url, req_data, request_id)
|
||||
)
|
||||
|
||||
session, decode_response = await decode_request_task
|
||||
|
||||
Reference in New Issue
Block a user