[Fix][MoRI] Align MoRI-IO message format with P2pNcclConnector and vllm-router (#39565)

Signed-off-by: simondanielsson <simon.danielsson99@hotmail.com>
Co-authored-by: Matvei Pashkovskii <mpashkov@amd.com>
This commit is contained in:
Simon Danielsson
2026-04-23 01:06:31 +02:00
committed by GitHub
parent b8401a9bf4
commit ac58e2a170
3 changed files with 202 additions and 90 deletions
@@ -10,7 +10,6 @@ import uuid
import aiohttp
import msgpack
import regex as re
import zmq
from quart import Quart, Request, make_response, request
@@ -25,32 +24,10 @@ decode_instances: list[dict] = []
request_nums = 0
app = Quart(__name__)
IP_PORT_PATTERN = re.compile(r"//(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}):(\d+)")
TRANSFER_TYPE = None
def _append_whole_dict_unique(target_list, data_dict):
    """Append ``data_dict`` to ``target_list`` unless an equal entry exists.

    Two entries are considered equal when all of their items match after the
    ``"index"`` key has been dropped from both.

    The first instance that is appended fixes the module-level
    ``TRANSFER_TYPE``; every later instance must report the same
    ``"transfer_mode"`` (a missing key is treated as ``"unknown"``).

    Args:
        target_list: List of already registered instance dicts; mutated
            in place when the new entry is accepted.
        data_dict: Registration message of the instance to add.

    Returns:
        ``True`` if the entry was appended, ``False`` if it was a duplicate.

    Raises:
        ValueError: If the transfer mode differs from ``TRANSFER_TYPE``.
            The list is left untouched in that case.
    """
    global TRANSFER_TYPE

    candidate = {k: v for k, v in data_dict.items() if k != "index"}
    for existing in target_list:
        if {k: v for k, v in existing.items() if k != "index"} == candidate:
            return False

    # Validate *before* mutating the list so that a rejected instance is
    # never left behind in the routing table.
    transfer_mode = data_dict.get("transfer_mode", "unknown")
    if TRANSFER_TYPE is not None and transfer_mode != TRANSFER_TYPE:
        raise ValueError(f"mismatched transfer mode {TRANSFER_TYPE} vs {transfer_mode}")

    target_list.append(data_dict)
    logger.info("Appended instance: %s", data_dict)

    if TRANSFER_TYPE is None:
        TRANSFER_TYPE = transfer_mode
        logger.info("SET TRANSFER TYPE TO %s", TRANSFER_TYPE)
    return True
_list_lock = threading.RLock()
@@ -68,23 +45,81 @@ def _listen_for_register(hostname, port):
if router_socket in socks:
remote_addr, msg = router_socket.recv_multipart()
data = msgpack.loads(msg)
if data["type"] == "HELLO":
if data.get("type") == "HELLO":
pass
elif (
data["type"] == "register"
and data["role"] == "P"
and data["request_address"] not in prefill_instances
):
with _list_lock:
_append_whole_dict_unique(prefill_instances, data)
elif data.get("type") in ("P", "D"):
role = data["type"]
required_keys = {
"http_address",
"zmq_address",
"dp_size",
"tp_size",
"transfer_mode",
}
missing = required_keys - data.keys()
if missing:
logger.error(
"Registration message missing required keys %s; skipping",
missing,
)
continue
# Derive request_address from http_address
# api path suffix is appended at request time
instance = {
"role": role,
"request_address": f"http://{data['http_address']}/v1",
"http_address": data["http_address"],
"zmq_address": data["zmq_address"],
"dp_size": data["dp_size"],
"tp_size": data["tp_size"],
"transfer_mode": data["transfer_mode"],
}
# zmq_address format: "host:IP,handshake:PORT,notify:PORT"
# Stored verbatim; embedded into the request_id by handle_request.
elif (
data["type"] == "register"
and data["role"] == "D"
and data["request_address"] not in decode_instances
):
global TRANSFER_TYPE
transfer_mode = instance["transfer_mode"]
target_list = prefill_instances if role == "P" else decode_instances
with _list_lock:
_append_whole_dict_unique(decode_instances, data)
if TRANSFER_TYPE is None:
TRANSFER_TYPE = transfer_mode
logger.info("SET TRANSFER TYPE TO %s", TRANSFER_TYPE)
elif transfer_mode != TRANSFER_TYPE:
logger.error(
"Mismatched transfer mode: expected %s, got %s;"
" skipping registration of %s",
TRANSFER_TYPE,
transfer_mode,
data["http_address"],
)
continue
existing_idx = next(
(
idx
for idx, i in enumerate(target_list)
if i.get("http_address") == data["http_address"]
),
None,
)
if existing_idx is not None:
target_list[existing_idx] = instance
logger.info(
"Updated existing %s instance: %s",
"Prefill" if role == "P" else "Decode",
instance,
)
else:
target_list.append(instance)
logger.info(
"Registered %s instance: %s",
"Prefill" if role == "P" else "Decode",
instance,
)
else:
logger.warning(
"Received message with unrecognized type %r; ignoring",
data.get("type"),
)
def start_service_discovery(hostname, port):
@@ -101,7 +136,7 @@ def start_service_discovery(hostname, port):
async def send_request_to_prefill(
endpoint, req_data, request_id, d_endpoint, dip, dport, selected_prefill_dp_rank
endpoint, req_data, request_id, selected_prefill_dp_rank
):
req_data_copy = req_data
@@ -109,12 +144,8 @@ async def send_request_to_prefill(
{
"do_remote_decode": True,
"do_remote_prefill": False,
"remote_handshake_port": d_endpoint["handshake_port"],
"remote_notify_port": d_endpoint["notify_port"],
"remote_engine_id": None,
"remote_block_ids": None,
"remote_host": dip,
"remote_port": dport,
}
)
req_data_copy["stream"] = False
@@ -197,14 +228,7 @@ async def handle_request(api: str, request: Request):
global request_nums
request_nums += 1
def extract_ip_port_fast(url):
match = IP_PORT_PATTERN.search(url)
if not match:
raise ValueError(f"Invalid URL format: {url}")
return match.groups()
req_data = await request.get_json()
request_id = str(uuid.uuid4())
prefill_instance_endpoint = None
decode_instance_endpoint = None
@@ -230,7 +254,14 @@ async def handle_request(api: str, request: Request):
prefill_instance_endpoint["dp_size"],
)
dip, dport = extract_ip_port_fast(decode_instance_endpoint["request_address"])
# Embed both zmq_addresses in the request_id so the connector can parse
# the peer's host/ports from it, similar to P2P-NCCL
uid = str(uuid.uuid4()).replace("-", "")
request_id = (
f"___prefill_addr_{prefill_instance_endpoint['zmq_address']}"
f"___decode_addr_{decode_instance_endpoint['zmq_address']}"
f"_{uid}"
)
transfer_id = f"{MoRIIOConstants.TRANSFER_PREFIX}-{str(uuid.uuid4())}"
@@ -251,35 +282,30 @@ async def handle_request(api: str, request: Request):
prefill_request_url,
req_data_to_prefill,
request_id,
decode_instance_endpoint,
dip,
dport,
selected_prefill_dp_rank,
)
)
ip, port = extract_ip_port_fast(prefill_request_url)
req_data["max_tokens"] -= 1
req_data["kv_transfer_params"] = {
"do_remote_decode": False,
"do_remote_prefill": True,
"remote_handshake_port": prefill_instance_endpoint["handshake_port"],
"remote_notify_port": prefill_instance_endpoint["notify_port"],
"remote_engine_id": None,
"remote_block_ids": None,
"remote_host": ip,
"remote_port": port,
"transfer_id": transfer_id,
}
if TRANSFER_TYPE == "READ":
# In read mode, prefill and decode are executed serially.
prefill_response = await send_prefill_task
req_data["kv_transfer_params"]["remote_engine_id"] = prefill_response[
"kv_transfer_params"
]["remote_engine_id"]
req_data["kv_transfer_params"]["remote_block_ids"] = prefill_response[
"kv_transfer_params"
]["remote_block_ids"]
prefill_kv = prefill_response["kv_transfer_params"]
req_data["kv_transfer_params"]["remote_engine_id"] = prefill_kv[
"remote_engine_id"
]
req_data["kv_transfer_params"]["remote_block_ids"] = prefill_kv[
"remote_block_ids"
]
req_data["kv_transfer_params"]["transfer_id"] = prefill_kv["transfer_id"]
req_data["kv_transfer_params"]["remote_dp_size"] = prefill_instance_endpoint[
"dp_size"
@@ -290,7 +316,6 @@ async def handle_request(api: str, request: Request):
if selected_prefill_dp_rank is not None:
req_data["kv_transfer_params"]["remote_dp_rank"] = selected_prefill_dp_rank
req_data["kv_transfer_params"]["transfer_id"] = transfer_id
decode_request_url = decode_instance_endpoint["request_address"] + api
decode_request_task = asyncio.create_task(
@@ -8,6 +8,7 @@ from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
import msgspec
import regex as re
import torch
import zmq
@@ -239,7 +240,7 @@ class MoRIIOConstants:
COMPLETION_PREFIX = "cmpl"
TRANSFER_PREFIX = "tx"
PING_INTERVAL = 5
PING_INTERVAL = 3
MAX_PING_RETRIES = 100
DEFAULT_HANDSHAKE_PORT = "6301"
DEFAULT_NOTIFY_PORT = "61005"
@@ -247,6 +248,64 @@ class MoRIIOConstants:
VLLM_MORI_READ_ABORT_REQUEST_TIMEOUT = 3600
# The router embeds both zmq_addresses in the request_id (similar to P2pNcclConnector):
# "___prefill_addr_{zmq}___decode_addr_{zmq}_{32-hex-uuid}"
# MoRIIO zmq_address format: "host:IP,handshake:PORT,notify:PORT"
#
# This lets each connector side parse the peer's connection info without
# requiring the router to pass it explicitly in kv_transfer_params.
#
# Group 1 is the prefill zmq_address. The match is lazy, so it stops at the
# first "___decode_addr_" marker that follows "___prefill_addr_".
_PREFILL_ZMQ_RE = re.compile(r"___prefill_addr_(.+?)___decode_addr_")
# vLLM wraps the router's X-Request-Id as "cmpl-<id>-<seq>-<hex>" so there may
# be a trailing "-<seq>-<hex>" suffix after the 32-char UUID. Allow it.
#
# Group 1 is the decode zmq_address. The match is greedy and is terminated by
# "_" followed by exactly 32 lowercase hex characters, then either the end of
# the string or a "-..." suffix.
_DECODE_ZMQ_RE = re.compile(r"___decode_addr_(.+)_[0-9a-f]{32}(?:-.*)?$")
def parse_moriio_zmq_address(
    zmq_address: str,
) -> tuple[str, int, int]:
    """Parse a MoRI-IO zmq address into its components.

    Parses ``"host:IP,handshake:PORT,notify:PORT"`` into
    ``(host, handshake_port, notify_port)``.

    Each comma-separated key-value pair is split on the *first* colon so that
    IPv6 addresses (e.g. ``host:::1``) are handled correctly. Surrounding
    whitespace around keys and values is ignored.

    Args:
        zmq_address: The address string to parse.

    Returns:
        A ``(host, handshake_port, notify_port)`` tuple.

    Raises:
        ValueError: If any of the ``host``, ``handshake`` or ``notify`` keys
            is absent, if the host is empty, or if a port is non-numeric or
            outside the range 1-65535.
    """
    malformed = (
        f"Malformed zmq_address {zmq_address!r}: expected "
        f"'host:IP,handshake:PORT,notify:PORT' format"
    )

    parts: dict[str, str] = {}
    for segment in zmq_address.split(","):
        key, _, val = segment.partition(":")
        parts[key.strip()] = val.strip()

    try:
        host = parts["host"]
        handshake_port = int(parts["handshake"])
        notify_port = int(parts["notify"])
    except (KeyError, ValueError) as e:
        raise ValueError(malformed) from e

    # "host:" with nothing after the colon parses to an empty string; reject
    # it here instead of failing later with an obscure connection error.
    if not host:
        raise ValueError(f"{malformed} (empty host)")
    for label, port in (("handshake", handshake_port), ("notify", notify_port)):
        if not 0 < port <= 65535:
            raise ValueError(f"{malformed} ({label} port {port} out of range)")

    return host, handshake_port, notify_port
def get_peer_zmq_from_request_id(request_id: str, is_producer: bool) -> str:
    """Return the zmq_address of the *peer* embedded in a router request_id.

    Args:
        request_id: Request id that carries both the prefill and the decode
            zmq_address.
        is_producer: ``True`` on the producer (prefill) side, which needs the
            decode address; ``False`` on the consumer (decode) side, which
            needs the prefill address.

    Returns:
        The peer's zmq_address exactly as captured from ``request_id``.

    Raises:
        ValueError: If the expected address marker is not found.
    """
    # Each side looks up the *other* side's address.
    pattern = _DECODE_ZMQ_RE if is_producer else _PREFILL_ZMQ_RE
    found = pattern.search(request_id)
    if found is not None:
        return found.group(1)
    raise ValueError(
        f"Cannot parse peer zmq_address from request_id: {request_id!r}"
    )
@dataclass
class ReqMeta:
"""Metadata for a single request."""
@@ -286,15 +345,23 @@ class MoRIIOConnectorMetadata(KVConnectorMetadata):
write_mode=False,
):
transfer_id = kv_transfer_params["transfer_id"]
# Parse host/ports from the request_id. The router embeds both zmq_addresses
# in the request_id
peer_zmq = get_peer_zmq_from_request_id(request_id, is_producer=write_mode)
remote_host, remote_handshake_port, remote_notify_port = (
parse_moriio_zmq_address(peer_zmq)
)
_req = ReqMeta(
transfer_id=transfer_id,
local_block_ids=local_block_ids,
remote_block_ids=kv_transfer_params["remote_block_ids"],
remote_engine_id=kv_transfer_params["remote_engine_id"],
remote_host=kv_transfer_params["remote_host"],
remote_port=kv_transfer_params["remote_port"],
remote_handshake_port=kv_transfer_params["remote_handshake_port"],
remote_notify_port=kv_transfer_params["remote_notify_port"],
remote_host=remote_host,
remote_port=remote_handshake_port,
remote_handshake_port=remote_handshake_port,
remote_notify_port=remote_notify_port,
tp_size=kv_transfer_params.get("tp_size", 1),
remote_dp_size=kv_transfer_params.get("remote_dp_size", 1),
)
@@ -35,8 +35,10 @@ from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common import (
TransferId,
WriteTask,
get_moriio_mode,
get_peer_zmq_from_request_id,
get_port_offset,
get_role,
parse_moriio_zmq_address,
set_role,
zmq_ctx,
)
@@ -379,13 +381,12 @@ class MoRIIOConnectorScheduler:
if params is not None and params.get("do_remote_prefill"):
if self.mode == MoRIIOMode.READ:
if remote_block_ids := params.get("remote_block_ids"):
if all(
p in params
for p in ("remote_engine_id", "remote_host", "remote_port")
):
# If remote_blocks and num_external_tokens = 0, we
# remote_engine_id is returned by the prefill's request_finished.
# host/ports come from the request_id (parsed in add_new_req).
if "remote_engine_id" in params:
# If remote_blocks and num_external_tokens = 0, we have
# a full prefix cache hit on the D worker. We need to call
# send_notif in _read_blocks to free the memory on the P.
# send_notify in _read_blocks to free the memory on the P.
# Get unhashed blocks to pull from remote.
local_block_ids = blocks.get_block_ids()[0]
@@ -407,22 +408,30 @@ class MoRIIOConnectorScheduler:
)
else:
# WRITE mode: prefill scheduler notifies the decode side that
# blocks are ready. Parse the decode's host/notify_port from
# the request_id
assert request.kv_transfer_params is not None, (
"kv_transfer_params should not be None"
)
remote_dp_rank = request.kv_transfer_params.get("remote_dp_rank", 0)
peer_zmq = get_peer_zmq_from_request_id(
request.request_id, is_producer=True
)
remote_host, _, remote_notify_port = parse_moriio_zmq_address(peer_zmq)
for tp_index in range(self.tp_size):
target_port = request.kv_transfer_params[
"remote_notify_port"
] + get_port_offset(remote_dp_rank, tp_index)
target_port = remote_notify_port + get_port_offset(
remote_dp_rank, tp_index
)
self.send_notify_block(
req_id=request.request_id,
transfer_id=request.kv_transfer_params["transfer_id"],
block_notify_list=blocks.get_block_ids()[0],
host=params.get("remote_host"),
host=remote_host,
port=target_port,
)
@@ -584,15 +593,15 @@ class MoRIIOConnectorScheduler:
+ MoRIIOConstants.VLLM_MORI_READ_ABORT_REQUEST_TIMEOUT
)
# If we execute in P-D serial mode, no notification port is needed.
# Return KV transfer params forwarded verbatim to the decode instance by
# the router.
return delay_free_blocks, dict(
do_remote_prefill=True,
do_remote_decode=False,
remote_block_ids=computed_block_ids,
remote_engine_id=self.engine_id,
remote_host=self.host_ip,
remote_port=self.handshake_port,
tp_size=self.vllm_config.parallel_config.tensor_parallel_size,
transfer_id=params["transfer_id"],
)
@@ -846,7 +855,15 @@ class MoRIIOConnectorWorker:
]
def _ping(self, zmq_context):
http_request_address = f"http://{self.request_address}/v1"
# Use host:port format for http_address (compatible with official router)
http_address = f"{self.request_address}"
# Include host so the router embeds it in the request_id; the connector
# on the other side parses host/ports from there.
zmq_address = (
f"host:{self.local_ip},"
f"handshake:{self.handshake_port},"
f"notify:{self.notify_port}"
)
role = "P" if self.is_producer else "D"
retry_count = 0
@@ -857,14 +874,17 @@ class MoRIIOConnectorWorker:
while True:
try:
data = {
"type": "register",
"role": role,
"index": str(index),
"request_address": http_request_address,
"handshake_port": self.handshake_port,
"notify_port": self.notify_port,
"type": role, # "P" or "D"
"http_address": http_address,
"zmq_address": zmq_address,
# dp_size/tp_size are not used by the official vLLM router
# (routing operates at the http_address level); they are
# consumed only by the toy proxy server.
"dp_size": self.moriio_config.dp_size,
"tp_size": self.moriio_config.tp_size,
# transfer_mode is included so the router can distinguish
# READ (prefill-then-decode, sequential) from WRITE (concurrent)
# scheduling.
"transfer_mode": self.mode.name,
}