mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[Fix][MoRI] Align MoRI-IO message format with P2pNcclConnector and vllm-router (#39565)
Signed-off-by: simondanielsson <simon.danielsson99@hotmail.com> Co-authored-by: Matvei Pashkovskii <mpashkov@amd.com>
This commit is contained in:
@@ -10,7 +10,6 @@ import uuid
|
||||
|
||||
import aiohttp
|
||||
import msgpack
|
||||
import regex as re
|
||||
import zmq
|
||||
from quart import Quart, Request, make_response, request
|
||||
|
||||
@@ -25,32 +24,10 @@ decode_instances: list[dict] = []
|
||||
request_nums = 0
|
||||
app = Quart(__name__)
|
||||
|
||||
IP_PORT_PATTERN = re.compile(r"//(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}):(\d+)")
|
||||
|
||||
|
||||
TRANSFER_TYPE = None
|
||||
|
||||
|
||||
def _append_whole_dict_unique(target_list, data_dict):
|
||||
new_filtered = {k: v for k, v in data_dict.items() if k != "index"}
|
||||
for existed in target_list:
|
||||
existed_filtered = {k: v for k, v in existed.items() if k != "index"}
|
||||
if existed_filtered == new_filtered:
|
||||
return False
|
||||
print("!!APPEND!!", data_dict)
|
||||
target_list.append(data_dict)
|
||||
transfer_mode = data_dict.get("transfer_mode", "unknown")
|
||||
global TRANSFER_TYPE
|
||||
|
||||
if TRANSFER_TYPE is None:
|
||||
TRANSFER_TYPE = transfer_mode
|
||||
logger.info("SET TRANSFER TYPE TO %s", TRANSFER_TYPE)
|
||||
elif transfer_mode != TRANSFER_TYPE:
|
||||
raise ValueError(f"mismatched transfer mode {TRANSFER_TYPE} vs {transfer_mode}")
|
||||
|
||||
return True
|
||||
|
||||
|
||||
_list_lock = threading.RLock()
|
||||
|
||||
|
||||
@@ -68,23 +45,81 @@ def _listen_for_register(hostname, port):
|
||||
if router_socket in socks:
|
||||
remote_addr, msg = router_socket.recv_multipart()
|
||||
data = msgpack.loads(msg)
|
||||
if data["type"] == "HELLO":
|
||||
if data.get("type") == "HELLO":
|
||||
pass
|
||||
elif (
|
||||
data["type"] == "register"
|
||||
and data["role"] == "P"
|
||||
and data["request_address"] not in prefill_instances
|
||||
):
|
||||
with _list_lock:
|
||||
_append_whole_dict_unique(prefill_instances, data)
|
||||
elif data.get("type") in ("P", "D"):
|
||||
role = data["type"]
|
||||
required_keys = {
|
||||
"http_address",
|
||||
"zmq_address",
|
||||
"dp_size",
|
||||
"tp_size",
|
||||
"transfer_mode",
|
||||
}
|
||||
missing = required_keys - data.keys()
|
||||
if missing:
|
||||
logger.error(
|
||||
"Registration message missing required keys %s; skipping",
|
||||
missing,
|
||||
)
|
||||
continue
|
||||
# Derive request_address from http_address
|
||||
# api path suffix is appended at request time
|
||||
instance = {
|
||||
"role": role,
|
||||
"request_address": f"http://{data['http_address']}/v1",
|
||||
"http_address": data["http_address"],
|
||||
"zmq_address": data["zmq_address"],
|
||||
"dp_size": data["dp_size"],
|
||||
"tp_size": data["tp_size"],
|
||||
"transfer_mode": data["transfer_mode"],
|
||||
}
|
||||
# zmq_address format: "host:IP,handshake:PORT,notify:PORT"
|
||||
# Stored verbatim; embedded into the request_id by handle_request.
|
||||
|
||||
elif (
|
||||
data["type"] == "register"
|
||||
and data["role"] == "D"
|
||||
and data["request_address"] not in decode_instances
|
||||
):
|
||||
global TRANSFER_TYPE
|
||||
transfer_mode = instance["transfer_mode"]
|
||||
target_list = prefill_instances if role == "P" else decode_instances
|
||||
with _list_lock:
|
||||
_append_whole_dict_unique(decode_instances, data)
|
||||
if TRANSFER_TYPE is None:
|
||||
TRANSFER_TYPE = transfer_mode
|
||||
logger.info("SET TRANSFER TYPE TO %s", TRANSFER_TYPE)
|
||||
elif transfer_mode != TRANSFER_TYPE:
|
||||
logger.error(
|
||||
"Mismatched transfer mode: expected %s, got %s;"
|
||||
" skipping registration of %s",
|
||||
TRANSFER_TYPE,
|
||||
transfer_mode,
|
||||
data["http_address"],
|
||||
)
|
||||
continue
|
||||
existing_idx = next(
|
||||
(
|
||||
idx
|
||||
for idx, i in enumerate(target_list)
|
||||
if i.get("http_address") == data["http_address"]
|
||||
),
|
||||
None,
|
||||
)
|
||||
if existing_idx is not None:
|
||||
target_list[existing_idx] = instance
|
||||
logger.info(
|
||||
"Updated existing %s instance: %s",
|
||||
"Prefill" if role == "P" else "Decode",
|
||||
instance,
|
||||
)
|
||||
else:
|
||||
target_list.append(instance)
|
||||
logger.info(
|
||||
"Registered %s instance: %s",
|
||||
"Prefill" if role == "P" else "Decode",
|
||||
instance,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"Received message with unrecognized type %r; ignoring",
|
||||
data.get("type"),
|
||||
)
|
||||
|
||||
|
||||
def start_service_discovery(hostname, port):
|
||||
@@ -101,7 +136,7 @@ def start_service_discovery(hostname, port):
|
||||
|
||||
|
||||
async def send_request_to_prefill(
|
||||
endpoint, req_data, request_id, d_endpoint, dip, dport, selected_prefill_dp_rank
|
||||
endpoint, req_data, request_id, selected_prefill_dp_rank
|
||||
):
|
||||
req_data_copy = req_data
|
||||
|
||||
@@ -109,12 +144,8 @@ async def send_request_to_prefill(
|
||||
{
|
||||
"do_remote_decode": True,
|
||||
"do_remote_prefill": False,
|
||||
"remote_handshake_port": d_endpoint["handshake_port"],
|
||||
"remote_notify_port": d_endpoint["notify_port"],
|
||||
"remote_engine_id": None,
|
||||
"remote_block_ids": None,
|
||||
"remote_host": dip,
|
||||
"remote_port": dport,
|
||||
}
|
||||
)
|
||||
req_data_copy["stream"] = False
|
||||
@@ -197,14 +228,7 @@ async def handle_request(api: str, request: Request):
|
||||
global request_nums
|
||||
request_nums += 1
|
||||
|
||||
def extract_ip_port_fast(url):
|
||||
match = IP_PORT_PATTERN.search(url)
|
||||
if not match:
|
||||
raise ValueError(f"Invalid URL format: {url}")
|
||||
return match.groups()
|
||||
|
||||
req_data = await request.get_json()
|
||||
request_id = str(uuid.uuid4())
|
||||
|
||||
prefill_instance_endpoint = None
|
||||
decode_instance_endpoint = None
|
||||
@@ -230,7 +254,14 @@ async def handle_request(api: str, request: Request):
|
||||
prefill_instance_endpoint["dp_size"],
|
||||
)
|
||||
|
||||
dip, dport = extract_ip_port_fast(decode_instance_endpoint["request_address"])
|
||||
# Embed both zmq_addresses in the request_id so the connector can parse
|
||||
# the peer's host/ports from it, similar to P2P-NCCL
|
||||
uid = str(uuid.uuid4()).replace("-", "")
|
||||
request_id = (
|
||||
f"___prefill_addr_{prefill_instance_endpoint['zmq_address']}"
|
||||
f"___decode_addr_{decode_instance_endpoint['zmq_address']}"
|
||||
f"_{uid}"
|
||||
)
|
||||
|
||||
transfer_id = f"{MoRIIOConstants.TRANSFER_PREFIX}-{str(uuid.uuid4())}"
|
||||
|
||||
@@ -251,35 +282,30 @@ async def handle_request(api: str, request: Request):
|
||||
prefill_request_url,
|
||||
req_data_to_prefill,
|
||||
request_id,
|
||||
decode_instance_endpoint,
|
||||
dip,
|
||||
dport,
|
||||
selected_prefill_dp_rank,
|
||||
)
|
||||
)
|
||||
ip, port = extract_ip_port_fast(prefill_request_url)
|
||||
|
||||
req_data["max_tokens"] -= 1
|
||||
|
||||
req_data["kv_transfer_params"] = {
|
||||
"do_remote_decode": False,
|
||||
"do_remote_prefill": True,
|
||||
"remote_handshake_port": prefill_instance_endpoint["handshake_port"],
|
||||
"remote_notify_port": prefill_instance_endpoint["notify_port"],
|
||||
"remote_engine_id": None,
|
||||
"remote_block_ids": None,
|
||||
"remote_host": ip,
|
||||
"remote_port": port,
|
||||
"transfer_id": transfer_id,
|
||||
}
|
||||
if TRANSFER_TYPE == "READ":
|
||||
# In read mode, prefill and decode are executed serially.
|
||||
prefill_response = await send_prefill_task
|
||||
req_data["kv_transfer_params"]["remote_engine_id"] = prefill_response[
|
||||
"kv_transfer_params"
|
||||
]["remote_engine_id"]
|
||||
req_data["kv_transfer_params"]["remote_block_ids"] = prefill_response[
|
||||
"kv_transfer_params"
|
||||
]["remote_block_ids"]
|
||||
prefill_kv = prefill_response["kv_transfer_params"]
|
||||
req_data["kv_transfer_params"]["remote_engine_id"] = prefill_kv[
|
||||
"remote_engine_id"
|
||||
]
|
||||
req_data["kv_transfer_params"]["remote_block_ids"] = prefill_kv[
|
||||
"remote_block_ids"
|
||||
]
|
||||
req_data["kv_transfer_params"]["transfer_id"] = prefill_kv["transfer_id"]
|
||||
|
||||
req_data["kv_transfer_params"]["remote_dp_size"] = prefill_instance_endpoint[
|
||||
"dp_size"
|
||||
@@ -290,7 +316,6 @@ async def handle_request(api: str, request: Request):
|
||||
|
||||
if selected_prefill_dp_rank is not None:
|
||||
req_data["kv_transfer_params"]["remote_dp_rank"] = selected_prefill_dp_rank
|
||||
req_data["kv_transfer_params"]["transfer_id"] = transfer_id
|
||||
|
||||
decode_request_url = decode_instance_endpoint["request_address"] + api
|
||||
decode_request_task = asyncio.create_task(
|
||||
|
||||
@@ -8,6 +8,7 @@ from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import msgspec
|
||||
import regex as re
|
||||
import torch
|
||||
import zmq
|
||||
|
||||
@@ -239,7 +240,7 @@ class MoRIIOConstants:
|
||||
COMPLETION_PREFIX = "cmpl"
|
||||
TRANSFER_PREFIX = "tx"
|
||||
|
||||
PING_INTERVAL = 5
|
||||
PING_INTERVAL = 3
|
||||
MAX_PING_RETRIES = 100
|
||||
DEFAULT_HANDSHAKE_PORT = "6301"
|
||||
DEFAULT_NOTIFY_PORT = "61005"
|
||||
@@ -247,6 +248,64 @@ class MoRIIOConstants:
|
||||
VLLM_MORI_READ_ABORT_REQUEST_TIMEOUT = 3600
|
||||
|
||||
|
||||
# The router embeds both zmq_addresses in the request_id (similar to P2pNcclConnector):
#   "___prefill_addr_{zmq}___decode_addr_{zmq}_{32-hex-uuid}"
# MoRIIO zmq_address format: "host:IP,handshake:PORT,notify:PORT"
#
# This lets each connector side parse the peer's connection info without
# requiring the router to pass it explicitly in kv_transfer_params.
#
# Group 1 is the prefill zmq_address. The match is lazy, so it ends at the
# first "___decode_addr_" marker that follows "___prefill_addr_".
_PREFILL_ZMQ_RE = re.compile(r"___prefill_addr_(.+?)___decode_addr_")
# vLLM wraps the router's X-Request-Id as "cmpl-<id>-<seq>-<hex>" so there may
# be a trailing "-<seq>-<hex>" suffix after the 32-char UUID. Allow it.
#
# Group 1 is the decode zmq_address: the text between "___decode_addr_" and
# the "_" that precedes 32 lowercase hex characters. The pattern is anchored
# at the end of the string, optionally followed by a "-..." suffix.
_DECODE_ZMQ_RE = re.compile(r"___decode_addr_(.+)_[0-9a-f]{32}(?:-.*)?$")
|
||||
|
||||
|
||||
def parse_moriio_zmq_address(
    zmq_address: str,
) -> tuple[str, int, int]:
    """Parse the MoRI-IO zmq address into its components.

    Parses ``"host:IP,handshake:PORT,notify:PORT"`` into
    (host, handshake_port, notify_port).

    Each comma-separated key-value pair is split on the *first* colon so that
    IPv6 addresses (e.g. ``host:::1``) are handled correctly. Whitespace
    around keys and values is stripped.

    Args:
        zmq_address: Address string in the format described above.

    Returns:
        A ``(host, handshake_port, notify_port)`` tuple.

    Raises:
        ValueError: If any of the ``host``, ``handshake`` or ``notify`` keys
            is absent, if ``host`` is empty, or if a port is non-numeric or
            outside the range 1-65535.
    """
    fields: dict[str, str] = {}
    for segment in zmq_address.split(","):
        key, _, val = segment.partition(":")
        fields[key.strip()] = val.strip()

    try:
        host = fields["host"]
        handshake_port = int(fields["handshake"])
        notify_port = int(fields["notify"])
    except (KeyError, ValueError) as e:
        raise ValueError(
            f"Malformed zmq_address {zmq_address!r}: expected "
            f"'host:IP,handshake:PORT,notify:PORT' format"
        ) from e

    # Reject values that parse but can never be connected to, so the failure
    # surfaces here rather than as an obscure socket error later on.
    if not host:
        raise ValueError(f"Malformed zmq_address {zmq_address!r}: empty host")
    for name, port in (("handshake", handshake_port), ("notify", notify_port)):
        if not 0 < port <= 65535:
            raise ValueError(
                f"Malformed zmq_address {zmq_address!r}: {name} port "
                f"{port} is outside the range 1-65535"
            )

    return host, handshake_port, notify_port
|
||||
|
||||
|
||||
def get_peer_zmq_from_request_id(request_id: str, is_producer: bool) -> str:
    """Return the zmq_address of the *other* side embedded in ``request_id``.

    A producer (prefill) gets the decode address, matched with
    ``_DECODE_ZMQ_RE``; a consumer (decode) gets the prefill address, matched
    with ``_PREFILL_ZMQ_RE``.

    Args:
        request_id: Request id carrying both zmq_addresses.
        is_producer: ``True`` to look up the decode address, ``False`` to
            look up the prefill address.

    Returns:
        The first capture group of the matching pattern.

    Raises:
        ValueError: If the selected pattern is not found in ``request_id``.
    """
    pattern = _DECODE_ZMQ_RE if is_producer else _PREFILL_ZMQ_RE
    found = pattern.search(request_id)
    if found is not None:
        return found.group(1)
    raise ValueError(
        f"Cannot parse peer zmq_address from request_id: {request_id!r}"
    )
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReqMeta:
|
||||
"""Metadata for a single request."""
|
||||
@@ -286,15 +345,23 @@ class MoRIIOConnectorMetadata(KVConnectorMetadata):
|
||||
write_mode=False,
|
||||
):
|
||||
transfer_id = kv_transfer_params["transfer_id"]
|
||||
|
||||
# Parse host/ports from the request_id. The router embeds both zmq_addresses
|
||||
# in the request_id
|
||||
peer_zmq = get_peer_zmq_from_request_id(request_id, is_producer=write_mode)
|
||||
remote_host, remote_handshake_port, remote_notify_port = (
|
||||
parse_moriio_zmq_address(peer_zmq)
|
||||
)
|
||||
|
||||
_req = ReqMeta(
|
||||
transfer_id=transfer_id,
|
||||
local_block_ids=local_block_ids,
|
||||
remote_block_ids=kv_transfer_params["remote_block_ids"],
|
||||
remote_engine_id=kv_transfer_params["remote_engine_id"],
|
||||
remote_host=kv_transfer_params["remote_host"],
|
||||
remote_port=kv_transfer_params["remote_port"],
|
||||
remote_handshake_port=kv_transfer_params["remote_handshake_port"],
|
||||
remote_notify_port=kv_transfer_params["remote_notify_port"],
|
||||
remote_host=remote_host,
|
||||
remote_port=remote_handshake_port,
|
||||
remote_handshake_port=remote_handshake_port,
|
||||
remote_notify_port=remote_notify_port,
|
||||
tp_size=kv_transfer_params.get("tp_size", 1),
|
||||
remote_dp_size=kv_transfer_params.get("remote_dp_size", 1),
|
||||
)
|
||||
|
||||
@@ -35,8 +35,10 @@ from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common import (
|
||||
TransferId,
|
||||
WriteTask,
|
||||
get_moriio_mode,
|
||||
get_peer_zmq_from_request_id,
|
||||
get_port_offset,
|
||||
get_role,
|
||||
parse_moriio_zmq_address,
|
||||
set_role,
|
||||
zmq_ctx,
|
||||
)
|
||||
@@ -379,13 +381,12 @@ class MoRIIOConnectorScheduler:
|
||||
if params is not None and params.get("do_remote_prefill"):
|
||||
if self.mode == MoRIIOMode.READ:
|
||||
if remote_block_ids := params.get("remote_block_ids"):
|
||||
if all(
|
||||
p in params
|
||||
for p in ("remote_engine_id", "remote_host", "remote_port")
|
||||
):
|
||||
# If remote_blocks and num_external_tokens = 0, we
|
||||
# remote_engine_id is returned by the prefill's request_finished.
|
||||
# host/ports come from the request_id (parsed in add_new_req).
|
||||
if "remote_engine_id" in params:
|
||||
# If remote_blocks and num_external_tokens = 0, we have
|
||||
# a full prefix cache hit on the D worker. We need to call
|
||||
# send_notif in _read_blocks to free the memory on the P.
|
||||
# send_notify in _read_blocks to free the memory on the P.
|
||||
|
||||
# Get unhashed blocks to pull from remote.
|
||||
local_block_ids = blocks.get_block_ids()[0]
|
||||
@@ -407,22 +408,30 @@ class MoRIIOConnectorScheduler:
|
||||
)
|
||||
|
||||
else:
|
||||
# WRITE mode: prefill scheduler notifies the decode side that
|
||||
# blocks are ready. Parse the decode's host/notify_port from
|
||||
# the request_id
|
||||
assert request.kv_transfer_params is not None, (
|
||||
"kv_transfer_params should not be None"
|
||||
)
|
||||
|
||||
remote_dp_rank = request.kv_transfer_params.get("remote_dp_rank", 0)
|
||||
|
||||
peer_zmq = get_peer_zmq_from_request_id(
|
||||
request.request_id, is_producer=True
|
||||
)
|
||||
remote_host, _, remote_notify_port = parse_moriio_zmq_address(peer_zmq)
|
||||
|
||||
for tp_index in range(self.tp_size):
|
||||
target_port = request.kv_transfer_params[
|
||||
"remote_notify_port"
|
||||
] + get_port_offset(remote_dp_rank, tp_index)
|
||||
target_port = remote_notify_port + get_port_offset(
|
||||
remote_dp_rank, tp_index
|
||||
)
|
||||
|
||||
self.send_notify_block(
|
||||
req_id=request.request_id,
|
||||
transfer_id=request.kv_transfer_params["transfer_id"],
|
||||
block_notify_list=blocks.get_block_ids()[0],
|
||||
host=params.get("remote_host"),
|
||||
host=remote_host,
|
||||
port=target_port,
|
||||
)
|
||||
|
||||
@@ -584,15 +593,15 @@ class MoRIIOConnectorScheduler:
|
||||
+ MoRIIOConstants.VLLM_MORI_READ_ABORT_REQUEST_TIMEOUT
|
||||
)
|
||||
|
||||
# If we execute in P-D serial mode, no notification port is needed.
|
||||
# Return KV transfer params forwarded verbatim to the decode instance by
|
||||
# the router.
|
||||
return delay_free_blocks, dict(
|
||||
do_remote_prefill=True,
|
||||
do_remote_decode=False,
|
||||
remote_block_ids=computed_block_ids,
|
||||
remote_engine_id=self.engine_id,
|
||||
remote_host=self.host_ip,
|
||||
remote_port=self.handshake_port,
|
||||
tp_size=self.vllm_config.parallel_config.tensor_parallel_size,
|
||||
transfer_id=params["transfer_id"],
|
||||
)
|
||||
|
||||
|
||||
@@ -846,7 +855,15 @@ class MoRIIOConnectorWorker:
|
||||
]
|
||||
|
||||
def _ping(self, zmq_context):
|
||||
http_request_address = f"http://{self.request_address}/v1"
|
||||
# Use host:port format for http_address (compatible with official router)
|
||||
http_address = f"{self.request_address}"
|
||||
# Include host so the router embeds it in the request_id; the connector
|
||||
# on the other side parses host/ports from there.
|
||||
zmq_address = (
|
||||
f"host:{self.local_ip},"
|
||||
f"handshake:{self.handshake_port},"
|
||||
f"notify:{self.notify_port}"
|
||||
)
|
||||
role = "P" if self.is_producer else "D"
|
||||
|
||||
retry_count = 0
|
||||
@@ -857,14 +874,17 @@ class MoRIIOConnectorWorker:
|
||||
while True:
|
||||
try:
|
||||
data = {
|
||||
"type": "register",
|
||||
"role": role,
|
||||
"index": str(index),
|
||||
"request_address": http_request_address,
|
||||
"handshake_port": self.handshake_port,
|
||||
"notify_port": self.notify_port,
|
||||
"type": role, # "P" or "D"
|
||||
"http_address": http_address,
|
||||
"zmq_address": zmq_address,
|
||||
# dp_size/tp_size are not used by the official vLLM router
|
||||
# (routing operates at the http_address level); they are
|
||||
# consumed only by the toy proxy server.
|
||||
"dp_size": self.moriio_config.dp_size,
|
||||
"tp_size": self.moriio_config.tp_size,
|
||||
# transfer_mode is included so the router can distinguish
|
||||
# READ (prefill-then-decode, sequential) from WRITE (concurrent)
|
||||
# scheduling.
|
||||
"transfer_mode": self.mode.name,
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user