[ROCm] Profiler api support for ROCm MORI toy proxy server in PD Disaggregation (#40264)

Signed-off-by: Tej Kiran <kiran.tej@amd.com>
This commit is contained in:
tej
2026-05-07 03:58:38 -05:00
committed by GitHub
parent 713b28bd0b
commit 8a4888be21
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import asyncio
import copy
import logging
@@ -7,6 +8,7 @@ import os
import socket
import threading
import uuid
from urllib.parse import urlparse
import aiohttp
import msgpack
@@ -336,11 +338,81 @@ async def handle_request(api: str, request: Request):
)
def _profile_url(request_address: str, profiler_cmd: str) -> str:
    """Return the ``/<profiler_cmd>_profile`` URL on the host of *request_address*."""
    parsed = urlparse(request_address)
    host = parsed.hostname
    if not host:
        raise ValueError(
            f"Cannot determine host from request address: {request_address!r}"
        )
    # ``hostname`` drops the brackets around IPv6 literals; put them back so
    # the rebuilt authority is still a valid URL component.
    if ":" in host:
        host = f"[{host}]"
    # An address without an explicit port must not become "host:None".
    authority = host if parsed.port is None else f"{host}:{parsed.port}"
    return f"http://{authority}/{profiler_cmd}_profile"


async def send_profile_cmd(req_data: dict, profiler_cmd: str):
    """Send a profiler command to every registered prefill and decode instance.

    The instance lists are copied while holding ``_list_lock``; the command is
    then POSTed concurrently to ``/<profiler_cmd>_profile`` on each instance,
    with ``req_data`` as the JSON body.

    Args:
        req_data: JSON payload forwarded unchanged to every instance.
        profiler_cmd: Either ``"start"`` or ``"stop"``.

    Returns:
        The decoded JSON body of the first instance's response.

    Raises:
        ValueError: If ``profiler_cmd`` is not ``"start"`` or ``"stop"``, or an
            instance address has no parsable host.
        RuntimeError: If no instance is registered, or any instance answers
            with an HTTP status >= 400.
        Exception: The first exception raised by any of the POST requests.
    """
    # Explicit check instead of ``assert`` so it survives ``python -O``.
    if profiler_cmd not in {"start", "stop"}:
        raise ValueError(
            f"Unsupported profiler command: {profiler_cmd!r} "
            "(expected 'start' or 'stop')"
        )

    # Copy under the lock so the lists can change while requests are in flight.
    with _list_lock:
        p_instances = list(prefill_instances)
        d_instances = list(decode_instances)

    if not p_instances and not d_instances:
        raise RuntimeError(
            "Service Unavailable: No prefill or decode instances are registered."
        )

    headers = {
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
    }

    async with aiohttp.ClientSession(
        timeout=aiohttp.ClientTimeout(total=60)
    ) as session:
        # Prefill instances first, then decode instances.
        tasks = [
            session.post(
                _profile_url(inst["request_address"], profiler_cmd),
                json=req_data,
                headers=headers,
            )
            for inst in (*p_instances, *d_instances)
        ]
        # Collect every outcome before failing so no request is left un-awaited.
        responses = await asyncio.gather(*tasks, return_exceptions=True)

        for r in responses:
            if isinstance(r, Exception):
                raise r
            if r.status >= 400:
                msg = await r.text()
                raise RuntimeError(
                    f"{profiler_cmd}_profile failed: {r.status}, {msg}"
                )

        # Bodies must be read while the session (and its connections) is open.
        return await responses[0].json()
@app.post("/start_profile")
async def start_profile():
    """Handle ``POST /start_profile`` by forwarding the body as a "start" command.

    Any exception is logged with its traceback and turned into a 500 response
    whose body is the exception text.
    """
    try:
        payload = await request.get_json()
        outcome = await send_profile_cmd(payload, "start")
    except Exception as exc:
        logger.exception("start_profile failed: %s", exc)
        outcome = await make_response((str(exc), 500))
    return outcome
@app.post("/stop_profile")
async def stop_profile():
    """Handle ``POST /stop_profile`` by forwarding the body as a "stop" command.

    Any exception is logged with its traceback and turned into a 500 response
    whose body is the exception text.
    """
    try:
        payload = await request.get_json()
        outcome = await send_profile_cmd(payload, "stop")
    except Exception as exc:
        logger.exception("stop_profile failed: %s", exc)
        outcome = await make_response((str(exc), 500))
    return outcome
if __name__ == "__main__":
    # Script entry point: parse the listen port, start service discovery,
    # then serve the proxy app until it exits.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--port",
        type=int,
        default=10001,
        help="Port the proxy HTTP server listens on (default: 10001).",
    )
    args = parser.parse_args()

    # NOTE(review): `t` is joined below, so it is presumably the background
    # thread running instance registration - confirm against
    # start_service_discovery.
    t = start_service_discovery("0.0.0.0", 36367)
    app.debug = True
    app.config["BODY_TIMEOUT"] = 360000
    app.config["RESPONSE_TIMEOUT"] = 360000
    # Serve exactly once, on the requested port. A leftover
    # `app.run(..., port=10001)` call used to precede this one, which blocked
    # on the hard-coded port and made --port ineffective.
    app.run(host="0.0.0.0", port=args.port)
    t.join()