diff --git a/examples/online_serving/disaggregated_serving/moriio_toy_proxy_server.py b/examples/online_serving/disaggregated_serving/moriio_toy_proxy_server.py index e2a0bfc7c9b..de4757f36b7 100644 --- a/examples/online_serving/disaggregated_serving/moriio_toy_proxy_server.py +++ b/examples/online_serving/disaggregated_serving/moriio_toy_proxy_server.py @@ -10,7 +10,6 @@ import uuid import aiohttp import msgpack -import regex as re import zmq from quart import Quart, Request, make_response, request @@ -25,32 +24,10 @@ decode_instances: list[dict] = [] request_nums = 0 app = Quart(__name__) -IP_PORT_PATTERN = re.compile(r"//(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}):(\d+)") - TRANSFER_TYPE = None -def _append_whole_dict_unique(target_list, data_dict): - new_filtered = {k: v for k, v in data_dict.items() if k != "index"} - for existed in target_list: - existed_filtered = {k: v for k, v in existed.items() if k != "index"} - if existed_filtered == new_filtered: - return False - print("!!APPEND!!", data_dict) - target_list.append(data_dict) - transfer_mode = data_dict.get("transfer_mode", "unknown") - global TRANSFER_TYPE - - if TRANSFER_TYPE is None: - TRANSFER_TYPE = transfer_mode - logger.info("SET TRANSFER TYPE TO %s", TRANSFER_TYPE) - elif transfer_mode != TRANSFER_TYPE: - raise ValueError(f"mismatched transfer mode {TRANSFER_TYPE} vs {transfer_mode}") - - return True - - _list_lock = threading.RLock() @@ -68,23 +45,81 @@ def _listen_for_register(hostname, port): if router_socket in socks: remote_addr, msg = router_socket.recv_multipart() data = msgpack.loads(msg) - if data["type"] == "HELLO": + if data.get("type") == "HELLO": pass - elif ( - data["type"] == "register" - and data["role"] == "P" - and data["request_address"] not in prefill_instances - ): - with _list_lock: - _append_whole_dict_unique(prefill_instances, data) + elif data.get("type") in ("P", "D"): + role = data["type"] + required_keys = { + "http_address", + "zmq_address", + "dp_size", + "tp_size", + "transfer_mode", + } + missing = required_keys - data.keys() + if missing: + logger.error( + "Registration message missing required keys %s; skipping", + missing, + ) + continue + # Derive request_address from http_address + # api path suffix is appended at request time + instance = { + "role": role, + "request_address": f"http://{data['http_address']}/v1", + "http_address": data["http_address"], + "zmq_address": data["zmq_address"], + "dp_size": data["dp_size"], + "tp_size": data["tp_size"], + "transfer_mode": data["transfer_mode"], + } + # zmq_address format: "host:IP,handshake:PORT,notify:PORT" + # Stored verbatim; embedded into the request_id by handle_request. - elif ( - data["type"] == "register" - and data["role"] == "D" - and data["request_address"] not in decode_instances - ): + global TRANSFER_TYPE + transfer_mode = instance["transfer_mode"] + target_list = prefill_instances if role == "P" else decode_instances with _list_lock: - _append_whole_dict_unique(decode_instances, data) + if TRANSFER_TYPE is None: + TRANSFER_TYPE = transfer_mode + logger.info("SET TRANSFER TYPE TO %s", TRANSFER_TYPE) + elif transfer_mode != TRANSFER_TYPE: + logger.error( + "Mismatched transfer mode: expected %s, got %s;" + " skipping registration of %s", + TRANSFER_TYPE, + transfer_mode, + data["http_address"], + ) + continue + existing_idx = next( + ( + idx + for idx, i in enumerate(target_list) + if i.get("http_address") == data["http_address"] + ), + None, + ) + if existing_idx is not None: + target_list[existing_idx] = instance + logger.info( + "Updated existing %s instance: %s", + "Prefill" if role == "P" else "Decode", + instance, + ) + else: + target_list.append(instance) + logger.info( + "Registered %s instance: %s", + "Prefill" if role == "P" else "Decode", + instance, + ) + else: + logger.warning( + "Received message with unrecognized type %r; ignoring", + data.get("type"), + ) def start_service_discovery(hostname, port): @@ -101,7 +136,7 @@ def start_service_discovery(hostname, port): async def send_request_to_prefill( - endpoint, req_data, request_id, d_endpoint, dip, dport, selected_prefill_dp_rank + endpoint, req_data, request_id, selected_prefill_dp_rank ): req_data_copy = req_data @@ -109,12 +144,8 @@ async def send_request_to_prefill( { "do_remote_decode": True, "do_remote_prefill": False, - "remote_handshake_port": d_endpoint["handshake_port"], - "remote_notify_port": d_endpoint["notify_port"], "remote_engine_id": None, "remote_block_ids": None, - "remote_host": dip, - "remote_port": dport, } ) req_data_copy["stream"] = False @@ -197,14 +228,7 @@ async def handle_request(api: str, request: Request): global request_nums request_nums += 1 - def extract_ip_port_fast(url): - match = IP_PORT_PATTERN.search(url) - if not match: - raise ValueError(f"Invalid URL format: {url}") - return match.groups() - req_data = await request.get_json() - request_id = str(uuid.uuid4()) prefill_instance_endpoint = None decode_instance_endpoint = None @@ -230,7 +254,14 @@ async def handle_request(api: str, request: Request): prefill_instance_endpoint["dp_size"], ) - dip, dport = extract_ip_port_fast(decode_instance_endpoint["request_address"]) + # Embed both zmq_addresses in the request_id so the connector can parse + # the peer's host/ports from it, similar to P2P-NCCL + uid = str(uuid.uuid4()).replace("-", "") + request_id = ( + f"___prefill_addr_{prefill_instance_endpoint['zmq_address']}" + f"___decode_addr_{decode_instance_endpoint['zmq_address']}" + f"_{uid}" + ) transfer_id = f"{MoRIIOConstants.TRANSFER_PREFIX}-{str(uuid.uuid4())}" @@ -251,35 +282,30 @@ async def handle_request(api: str, request: Request): prefill_request_url, req_data_to_prefill, request_id, - decode_instance_endpoint, - dip, - dport, selected_prefill_dp_rank, ) ) - ip, port = extract_ip_port_fast(prefill_request_url) req_data["max_tokens"] -= 1 req_data["kv_transfer_params"] = { "do_remote_decode": False, "do_remote_prefill": True, - "remote_handshake_port": prefill_instance_endpoint["handshake_port"], - "remote_notify_port": prefill_instance_endpoint["notify_port"], "remote_engine_id": None, "remote_block_ids": None, - "remote_host": ip, - "remote_port": port, + "transfer_id": transfer_id, } if TRANSFER_TYPE == "READ": # In read mode, prefill and decode are executed serially. prefill_response = await send_prefill_task - req_data["kv_transfer_params"]["remote_engine_id"] = prefill_response[ - "kv_transfer_params" - ]["remote_engine_id"] - req_data["kv_transfer_params"]["remote_block_ids"] = prefill_response[ - "kv_transfer_params" - ]["remote_block_ids"] + prefill_kv = prefill_response["kv_transfer_params"] + req_data["kv_transfer_params"]["remote_engine_id"] = prefill_kv[ + "remote_engine_id" + ] + req_data["kv_transfer_params"]["remote_block_ids"] = prefill_kv[ + "remote_block_ids" + ] + req_data["kv_transfer_params"]["transfer_id"] = prefill_kv["transfer_id"] req_data["kv_transfer_params"]["remote_dp_size"] = prefill_instance_endpoint[ "dp_size" @@ -290,7 +316,6 @@ async def handle_request(api: str, request: Request): if selected_prefill_dp_rank is not None: req_data["kv_transfer_params"]["remote_dp_rank"] = selected_prefill_dp_rank - req_data["kv_transfer_params"]["transfer_id"] = transfer_id decode_request_url = decode_instance_endpoint["request_address"] + api decode_request_task = asyncio.create_task( diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_common.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_common.py index f3b2ce3b5be..b843c5b5930 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_common.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_common.py @@ -8,6 +8,7 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any import msgspec +import regex as re import torch import zmq @@ -239,7 +240,7 @@ class MoRIIOConstants: COMPLETION_PREFIX = "cmpl" TRANSFER_PREFIX = "tx" - PING_INTERVAL = 5 + PING_INTERVAL = 3 MAX_PING_RETRIES = 100 DEFAULT_HANDSHAKE_PORT = "6301" DEFAULT_NOTIFY_PORT = "61005" @@ -247,6 +248,64 @@ class MoRIIOConstants: VLLM_MORI_READ_ABORT_REQUEST_TIMEOUT = 3600 +# The router embeds both zmq_addresses in the request_id (similar to P2pNcclConnector): +# "___prefill_addr_{zmq}___decode_addr_{zmq}_{32-hex-uuid}" +# MoRIIO zmq_address format: "host:IP,handshake:PORT,notify:PORT" +# +# This lets each connector side parse the peer's connection info without +# requiring the router to pass it explicitly in kv_transfer_params. +_PREFILL_ZMQ_RE = re.compile(r"___prefill_addr_(.+?)___decode_addr_") +# vLLM wraps the router's X-Request-Id as "cmpl---" so there may +# be a trailing "--" suffix after the 32-char UUID. Allow it. +_DECODE_ZMQ_RE = re.compile(r"___decode_addr_(.+)_[0-9a-f]{32}(?:-.*)?$") + + +def parse_moriio_zmq_address( + zmq_address: str, +) -> tuple[str, int, int]: + """Parse the MoRI-IO zmq address into its components. + + Parses ``"host:IP,handshake:PORT,notify:PORT"`` into + (host, handshake_port, notify_port). + + Each key-value pair is split on the *first* colon so that IPv6 addresses + (e.g. ``host:::1``) are handled correctly. Raises ``ValueError`` if any + of ``host``, ``handshake``, or ``notify`` keys are absent or if the port + values are non-numeric. + """ + parts: dict[str, str] = {} + for segment in zmq_address.split(","): + key, _, val = segment.partition(":") + parts[key.strip()] = val.strip() + try: + host = parts["host"] + handshake_port = int(parts["handshake"]) + notify_port = int(parts["notify"]) + except (KeyError, ValueError) as e: + raise ValueError( + f"Malformed zmq_address {zmq_address!r}: expected " + f"'host:IP,handshake:PORT,notify:PORT' format" + ) from e + return host, handshake_port, notify_port + + +def get_peer_zmq_from_request_id(request_id: str, is_producer: bool) -> str: + """Extract the *peer's* zmq_address from the vLLM router request_id. + + The producer (prefill) needs the decode's address; the consumer (decode) + needs the prefill's address. + """ + if is_producer: + m = _DECODE_ZMQ_RE.search(request_id) + else: + m = _PREFILL_ZMQ_RE.search(request_id) + if m is None: + raise ValueError( + f"Cannot parse peer zmq_address from request_id: {request_id!r}" + ) + return m.group(1) + + @dataclass class ReqMeta: """Metadata for a single request.""" @@ -286,15 +345,23 @@ class MoRIIOConnectorMetadata(KVConnectorMetadata): write_mode=False, ): transfer_id = kv_transfer_params["transfer_id"] + + # Parse host/ports from the request_id. The router embeds both zmq_addresses + # in the request_id + peer_zmq = get_peer_zmq_from_request_id(request_id, is_producer=write_mode) + remote_host, remote_handshake_port, remote_notify_port = ( + parse_moriio_zmq_address(peer_zmq) + ) + _req = ReqMeta( transfer_id=transfer_id, local_block_ids=local_block_ids, remote_block_ids=kv_transfer_params["remote_block_ids"], remote_engine_id=kv_transfer_params["remote_engine_id"], - remote_host=kv_transfer_params["remote_host"], - remote_port=kv_transfer_params["remote_port"], - remote_handshake_port=kv_transfer_params["remote_handshake_port"], - remote_notify_port=kv_transfer_params["remote_notify_port"], + remote_host=remote_host, + remote_port=remote_handshake_port, + remote_handshake_port=remote_handshake_port, + remote_notify_port=remote_notify_port, tp_size=kv_transfer_params.get("tp_size", 1), remote_dp_size=kv_transfer_params.get("remote_dp_size", 1), ) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_connector.py index 0fd6d81f23e..15aca3e571c 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_connector.py @@ -35,8 +35,10 @@ from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common import ( TransferId, WriteTask, get_moriio_mode, + get_peer_zmq_from_request_id, get_port_offset, get_role, + parse_moriio_zmq_address, set_role, zmq_ctx, ) @@ -379,13 +381,12 @@ class MoRIIOConnectorScheduler: if params is not None and params.get("do_remote_prefill"): if self.mode == MoRIIOMode.READ: if remote_block_ids := params.get("remote_block_ids"): - if all( - p in params - for p in ("remote_engine_id", "remote_host", "remote_port") - ): - # If remote_blocks and num_external_tokens = 0, we + # remote_engine_id is returned by the prefill's request_finished. + # host/ports come from the request_id (parsed in add_new_req). + if "remote_engine_id" in params: + # If remote_blocks and num_external_tokens = 0, we have # a full prefix cache hit on the D worker. We need to call - # send_notif in _read_blocks to free the memory on the P. + # send_notify in _read_blocks to free the memory on the P. # Get unhashed blocks to pull from remote. local_block_ids = blocks.get_block_ids()[0] @@ -407,22 +408,30 @@ class MoRIIOConnectorScheduler: ) else: + # WRITE mode: prefill scheduler notifies the decode side that + # blocks are ready. Parse the decode's host/notify_port from + # the request_id assert request.kv_transfer_params is not None, ( "kv_transfer_params should not be None" ) remote_dp_rank = request.kv_transfer_params.get("remote_dp_rank", 0) + peer_zmq = get_peer_zmq_from_request_id( + request.request_id, is_producer=True + ) + remote_host, _, remote_notify_port = parse_moriio_zmq_address(peer_zmq) + for tp_index in range(self.tp_size): - target_port = request.kv_transfer_params[ - "remote_notify_port" - ] + get_port_offset(remote_dp_rank, tp_index) + target_port = remote_notify_port + get_port_offset( + remote_dp_rank, tp_index + ) self.send_notify_block( req_id=request.request_id, transfer_id=request.kv_transfer_params["transfer_id"], block_notify_list=blocks.get_block_ids()[0], - host=params.get("remote_host"), + host=remote_host, port=target_port, ) @@ -584,15 +593,15 @@ class MoRIIOConnectorScheduler: + MoRIIOConstants.VLLM_MORI_READ_ABORT_REQUEST_TIMEOUT ) - # If we execute in P-D serial mode, no notification port is needed. + # Return KV transfer params forwarded verbatim to the decode instance by + # the router. return delay_free_blocks, dict( do_remote_prefill=True, do_remote_decode=False, remote_block_ids=computed_block_ids, remote_engine_id=self.engine_id, - remote_host=self.host_ip, - remote_port=self.handshake_port, tp_size=self.vllm_config.parallel_config.tensor_parallel_size, + transfer_id=params["transfer_id"], ) @@ -846,7 +855,15 @@ class MoRIIOConnectorWorker: ] def _ping(self, zmq_context): - http_request_address = f"http://{self.request_address}/v1" + # Use host:port format for http_address (compatible with official router) + http_address = f"{self.request_address}" + # Include host so the router embeds it in the request_id; the connector + # on the other side parses host/ports from there. + zmq_address = ( + f"host:{self.local_ip}," + f"handshake:{self.handshake_port}," + f"notify:{self.notify_port}" + ) role = "P" if self.is_producer else "D" retry_count = 0 @@ -857,14 +874,17 @@ class MoRIIOConnectorWorker: while True: try: data = { - "type": "register", - "role": role, - "index": str(index), - "request_address": http_request_address, - "handshake_port": self.handshake_port, - "notify_port": self.notify_port, + "type": role, # "P" or "D" + "http_address": http_address, + "zmq_address": zmq_address, + # dp_size/tp_size are not used by the official vLLM router + # (routing operates at the http_address level); they are + # consumed only by the toy proxy server. "dp_size": self.moriio_config.dp_size, "tp_size": self.moriio_config.tp_size, + # transfer_mode is included so the router can distinguish + # READ (prefill-then-decode, sequential) from WRITE (concurrent) + # scheduling. "transfer_mode": self.mode.name, }