From 8a4888be21219bf081b5bc7e01f8ecf2af9057a8 Mon Sep 17 00:00:00 2001 From: tej <37236721+itej89@users.noreply.github.com> Date: Thu, 7 May 2026 03:58:38 -0500 Subject: [PATCH] [ROCm] Profiler api support for ROCm MORI toy proxy server in PD Disaggregation (#40264) Signed-off-by: Tej Kiran --- .../moriio_toy_proxy_server.py | 74 ++++++++++++++++++- 1 file changed, 73 insertions(+), 1 deletion(-) diff --git a/examples/disaggregated/disaggregated_serving/moriio_toy_proxy_server.py b/examples/disaggregated/disaggregated_serving/moriio_toy_proxy_server.py index de4757f36b7..aceb7a9b81c 100644 --- a/examples/disaggregated/disaggregated_serving/moriio_toy_proxy_server.py +++ b/examples/disaggregated/disaggregated_serving/moriio_toy_proxy_server.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import argparse import asyncio import copy import logging @@ -7,6 +8,7 @@ import os import socket import threading import uuid +from urllib.parse import urlparse import aiohttp import msgpack @@ -336,11 +338,81 @@ async def handle_request(api: str, request: Request): ) +async def send_profile_cmd(req_data: dict, profiler_cmd: str): + assert profiler_cmd in {"start", "stop"} + + with _list_lock: + p_instances = list(prefill_instances) + d_instances = list(decode_instances) + + if not p_instances and not d_instances: + raise RuntimeError( + "Service Unavailable: No prefill or decode instances are registered." + ) + + headers = { + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + } + + tasks = [] + + async with aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=60) + ) as session: + for instances in (p_instances, d_instances): + for inst in instances: + _p = urlparse(inst["request_address"]) + url = f"http://{_p.hostname}:{_p.port}/{profiler_cmd}_profile" + + tasks.append( + session.post( + url, + json=req_data, + headers=headers, + ) + ) + + responses = await asyncio.gather(*tasks, return_exceptions=True) + + for r in responses: + if isinstance(r, Exception): + raise r + if r.status >= 400: + msg = await r.text() + raise RuntimeError(f"{profiler_cmd}_profile failed: {r.status}, {msg}") + + return await responses[0].json() + + +@app.post("/start_profile") +async def start_profile(): + try: + req_data = await request.get_json() + return await send_profile_cmd(req_data, "start") + except Exception as e: + logger.exception("start_profile failed: %s", e) + return await make_response((str(e), 500)) + + +@app.post("/stop_profile") +async def stop_profile(): + try: + req_data = await request.get_json() + return await send_profile_cmd(req_data, "stop") + except Exception as e: + logger.exception("stop_profile failed: %s", e) + return await make_response((str(e), 500)) + + if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--port", type=int, default=10001) + args = parser.parse_args() + t = start_service_discovery("0.0.0.0", 36367) app.debug = True app.config["BODY_TIMEOUT"] = 360000 app.config["RESPONSE_TIMEOUT"] = 360000 - app.run(host="0.0.0.0", port=10001) + app.run(host="0.0.0.0", port=args.port) t.join()