mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[ROCm] Profiler api support for ROCm MORI toy proxy server in PD Disaggregation (#40264)
Signed-off-by: Tej Kiran <kiran.tej@amd.com>
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import argparse
|
||||
import asyncio
|
||||
import copy
|
||||
import logging
|
||||
@@ -7,6 +8,7 @@ import os
|
||||
import socket
|
||||
import threading
|
||||
import uuid
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import aiohttp
|
||||
import msgpack
|
||||
@@ -336,11 +338,81 @@ async def handle_request(api: str, request: Request):
|
||||
)
|
||||
|
||||
|
||||
async def send_profile_cmd(req_data: dict, profiler_cmd: str):
    """Broadcast a profiler command to every registered instance.

    A snapshot of ``prefill_instances`` and ``decode_instances`` is taken
    under ``_list_lock``; ``req_data`` is then POSTed as JSON to
    ``http://<host>:<port>/<profiler_cmd>_profile`` on each of them, where
    host and port come from the instance's ``"request_address"`` entry.
    All requests share one client session with a 60 second total timeout.

    Args:
        req_data: JSON payload forwarded unchanged to every instance.
        profiler_cmd: Either ``"start"`` or ``"stop"``.

    Returns:
        The decoded JSON body of the first response (prefill instances are
        contacted before decode instances).

    Raises:
        ValueError: If ``profiler_cmd`` is not ``"start"`` or ``"stop"``.
        RuntimeError: If no instance is registered, or if any instance
            answers with an HTTP status >= 400.
        BaseException: Any exception raised by an individual request is
            re-raised as-is.
    """
    # Validate explicitly: an ``assert`` would be stripped under ``python -O``.
    if profiler_cmd not in {"start", "stop"}:
        raise ValueError(
            f"profiler_cmd must be 'start' or 'stop', got {profiler_cmd!r}"
        )

    # Copy the registries while holding the lock so the fan-out below works
    # on a stable snapshot.
    with _list_lock:
        p_instances = list(prefill_instances)
        d_instances = list(decode_instances)

    if not p_instances and not d_instances:
        raise RuntimeError(
            "Service Unavailable: No prefill or decode instances are registered."
        )

    headers = {
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
    }

    async with aiohttp.ClientSession(
        timeout=aiohttp.ClientTimeout(total=60)
    ) as session:
        tasks = []
        for inst in (*p_instances, *d_instances):
            _p = urlparse(inst["request_address"])
            url = f"http://{_p.hostname}:{_p.port}/{profiler_cmd}_profile"
            tasks.append(session.post(url, json=req_data, headers=headers))

        # Collect failures as results so that every request is awaited
        # before the first error is reported.
        responses = await asyncio.gather(*tasks, return_exceptions=True)

        # The responses are inspected while the session is still open
        # because r.text() / r.json() read the body at this point.
        for r in responses:
            # BaseException (not just Exception) so that a CancelledError
            # returned by gather() is re-raised instead of being treated as
            # a response object.
            if isinstance(r, BaseException):
                raise r
            if r.status >= 400:
                msg = await r.text()
                raise RuntimeError(f"{profiler_cmd}_profile failed: {r.status}, {msg}")

        return await responses[0].json()
|
||||
|
||||
|
||||
@app.post("/start_profile")
async def start_profile():
    """Handle POST /start_profile.

    Reads the JSON body of the incoming request and passes it to
    ``send_profile_cmd`` with the ``"start"`` command. On success the value
    returned by ``send_profile_cmd`` is returned; on any exception the
    failure is logged and a 500 response carrying the error text is
    returned instead.
    """
    try:
        payload = await request.get_json()
        result = await send_profile_cmd(payload, "start")
    except Exception as err:
        logger.exception("start_profile failed: %s", err)
        return await make_response((str(err), 500))
    return result
|
||||
|
||||
|
||||
@app.post("/stop_profile")
async def stop_profile():
    """Handle POST /stop_profile.

    Reads the JSON body of the incoming request and passes it to
    ``send_profile_cmd`` with the ``"stop"`` command. On success the value
    returned by ``send_profile_cmd`` is returned; on any exception the
    failure is logged and a 500 response carrying the error text is
    returned instead.
    """
    try:
        payload = await request.get_json()
        result = await send_profile_cmd(payload, "stop")
    except Exception as err:
        logger.exception("stop_profile failed: %s", err)
        return await make_response((str(err), 500))
    return result
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # --port selects the HTTP port the proxy listens on (default 10001).
    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int, default=10001)
    args = parser.parse_args()

    # Start service discovery before serving. The returned handle is joined
    # after the server exits (presumably a background thread -- confirm
    # against start_service_discovery).
    t = start_service_discovery("0.0.0.0", 36367)
    app.debug = True
    app.config["BODY_TIMEOUT"] = 360000
    app.config["RESPONSE_TIMEOUT"] = 360000

    # Serve exactly once, on the port given on the command line; a call with
    # a hard-coded port would ignore --port.
    app.run(host="0.0.0.0", port=args.port)

    t.join()
|
||||
|
||||
Reference in New Issue
Block a user