mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[P/D] Prefill compute optimizations with bi-directional KV cache transfers between P and D nodes (#32553)
Signed-off-by: Sunita Nadampalli <nadampal@amazon.com>
This commit is contained in:
@@ -217,6 +217,7 @@ async def send_request(
|
||||
min_tokens: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
timeout_sec: int = 120,
|
||||
conversation_id: str | None = None,
|
||||
) -> ServerResponse:
|
||||
payload = {
|
||||
"model": model,
|
||||
@@ -225,6 +226,9 @@ async def send_request(
|
||||
"temperature": 0.0,
|
||||
}
|
||||
|
||||
if conversation_id is not None:
|
||||
payload["conversation_id"] = conversation_id
|
||||
|
||||
if stream:
|
||||
payload["stream"] = True
|
||||
payload["stream_options"] = {"include_usage": False}
|
||||
@@ -419,6 +423,7 @@ async def send_turn(
|
||||
min_tokens,
|
||||
max_tokens,
|
||||
req_args.timeout_sec,
|
||||
conversation_id=conv_id,
|
||||
)
|
||||
|
||||
if response.valid is False:
|
||||
|
||||
@@ -0,0 +1,562 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Disaggregated Prefill/Decode Proxy with Bidirectional KV Transfer
|
||||
|
||||
This proxy sits between clients and a vLLM Prefill/Decode (P/D) deployment,
|
||||
routing multi-turn chat requests so that each turn reuses KV cache blocks
|
||||
from the previous turn's Decode node via bidirectional KV transfer.
|
||||
|
||||
Architecture:
|
||||
Client ──► Proxy ──► Prefill (P) ──► Decode (D)
|
||||
│ │ │
|
||||
│ kv_transfer_params flow: │
|
||||
│ D finish ──► proxy caches │
|
||||
│ next turn ──► proxy sends │
|
||||
│ cached D blocks to P ──► │
|
||||
│ P reads D blocks (bidir) │
|
||||
│ P sends its blocks to D │
|
||||
|
||||
Per-request flow:
|
||||
1. Client sends chat/completions request to proxy.
|
||||
2. Proxy looks up cached D block info from the previous turn
|
||||
(keyed by conversation_id).
|
||||
3. If cache hit, proxy attaches D's block info to the request
|
||||
so P can read D's KV blocks instead of recomputing.
|
||||
4. Proxy sends request to P (max_tokens=1, non-streaming).
|
||||
5. P returns kv_transfer_params with its own block info.
|
||||
6. Proxy forwards request + P's block info to D (streaming).
|
||||
7. D streams the response. The final chunk includes D's
|
||||
kv_transfer_params, which the proxy caches for the next turn.
|
||||
8. Proxy returns D's response to the client.
|
||||
|
||||
Conversation isolation:
|
||||
Each request should include a ``conversation_id`` field (top-level in
|
||||
the JSON body) to scope the KV cache across turns. Without it, the
|
||||
proxy cannot link turns and falls back to no-cache behavior.
|
||||
|
||||
Usage:
|
||||
python disagg_proxy_multiturn.py \\
|
||||
--host 0.0.0.0 --port 8000 \\
|
||||
--prefiller-host 10.0.0.1 --prefiller-port 8100 \\
|
||||
--decoder-host 10.0.0.2 --decoder-port 8200
|
||||
|
||||
Dependencies:
|
||||
pip install fastapi uvicorn httpx
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import itertools
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
|
||||
# Logging: the proxy logs under its own logger name; the root logger is
# configured for timestamped INFO-level records.
logger = logging.getLogger("disagg_proxy")
logging.basicConfig(
    level=logging.INFO,
    datefmt="%Y-%m-%d %H:%M:%S",
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)
|
||||
|
||||
|
||||
# Data structures
|
||||
@dataclass
class CachedKVEntry:
    """A single turn's KV transfer parameters, as captured from D's response.

    The creation time is recorded alongside the parameters so that readers
    can compute the entry's age.
    """

    # Parameters reported by D for the turn.
    kv_transfer_params: dict[str, Any]
    # Creation time from ``time.time()``; filled in when not supplied.
    timestamp: float = field(default_factory=time.time)
|
||||
|
||||
|
||||
class ConversationKVCache:
    """In-memory store of D's ``kv_transfer_params``, keyed by conversation.

    The client supplies a ``conversation_id`` with each request. When D
    finishes a turn its ``kv_transfer_params`` are saved under that id, and
    the following turn takes them back out so they can be handed to P.
    Entries older than the configured TTL are treated as unusable.
    """

    def __init__(self, ttl_seconds: float = 600.0) -> None:
        """Create an empty cache whose entries expire after *ttl_seconds*."""
        self._ttl = ttl_seconds
        self._store: dict[str, CachedKVEntry] = {}

    def get(self, conversation_id: str) -> dict[str, Any] | None:
        """Take the cached KV params for a conversation out of the cache.

        The entry is removed whether or not it is still fresh, so each
        stored entry can be read at most once.

        Returns:
            A shallow copy of the stored ``kv_transfer_params``, or ``None``
            when nothing is stored or the entry is older than the TTL.
        """
        cached = self._store.pop(conversation_id, None)
        if cached is None:
            return None

        elapsed = time.time() - cached.timestamp
        if elapsed > self._ttl:
            logger.info(
                "conv=%s: stale cache entry (age=%.1fs > ttl=%.1fs), discarding",
                conversation_id,
                elapsed,
                self._ttl,
            )
            return None

        logger.info("conv=%s: cache HIT (age=%.1fs)", conversation_id, elapsed)
        return dict(cached.kv_transfer_params)

    def put(self, conversation_id: str, kv_params: dict[str, Any]) -> None:
        """Save D's ``kv_transfer_params`` for a conversation.

        A shallow copy is stored, and any earlier entry for the same
        conversation is replaced.
        """
        self._store[conversation_id] = CachedKVEntry(
            kv_transfer_params=dict(kv_params),
        )

        # Only the first group of block ids is counted for the log line.
        block_groups = kv_params.get("remote_block_ids")
        num_blocks = len(block_groups[0]) if block_groups else 0
        logger.info(
            "conv=%s: cached D blocks (remote_request_id=%s, blocks=%d)",
            conversation_id,
            kv_params.get("remote_request_id", "?"),
            num_blocks,
        )

    def evict_stale(self) -> int:
        """Drop every entry older than the TTL and return how many were dropped."""
        now = time.time()
        evicted = 0
        # Iterate over a snapshot of the keys because entries are deleted
        # while walking the store.
        for cid in list(self._store):
            if now - self._store[cid].timestamp > self._ttl:
                del self._store[cid]
                evicted += 1
        return evicted

    @property
    def size(self) -> int:
        """Number of conversations currently held in the cache."""
        return len(self._store)
|
||||
|
||||
|
||||
# Global state
|
||||
# The TTL must stay below VLLM_NIXL_ABORT_REQUEST_TIMEOUT (480s).
kv_cache = ConversationKVCache(ttl_seconds=450.0)
|
||||
|
||||
|
||||
# Service client helpers
|
||||
@dataclass
class ServiceClient:
    """An upstream P or D instance bundled with the HTTP client that reaches it."""

    # Async HTTP client used for all requests to this instance.
    client: httpx.AsyncClient
    # Host and port the instance is reached at.
    host: str
    port: int
    # Numeric identifier of the instance.
    id: int
|
||||
|
||||
|
||||
def _make_headers(request_id: str) -> dict[str, str]:
|
||||
"""Build HTTP headers for upstream requests."""
|
||||
headers = {"X-Request-Id": request_id}
|
||||
api_key = os.environ.get("OPENAI_API_KEY")
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
return headers
|
||||
|
||||
|
||||
async def _send_to_prefill(
    client: ServiceClient,
    endpoint: str,
    req_data: dict[str, Any],
    request_id: str,
) -> dict[str, Any]:
    """Run the prefill step on P and return its parsed JSON response.

    The caller's request is left untouched. The copy sent to P is forced to
    be non-streaming with ``max_tokens=1``, and the
    ``max_completion_tokens``, ``min_tokens`` and ``stream_options`` fields
    are removed from it. The returned JSON is where the caller looks for
    P's ``kv_transfer_params``.
    """
    dropped_fields = ("max_completion_tokens", "min_tokens", "stream_options")
    prefill_body = {
        key: value for key, value in req_data.items() if key not in dropped_fields
    }
    prefill_body["stream"] = False
    prefill_body["max_tokens"] = 1

    response = await client.client.post(
        endpoint,
        json=prefill_body,
        headers=_make_headers(request_id),
    )
    # Fail here on an upstream error status instead of parsing an error body.
    response.raise_for_status()
    return response.json()
|
||||
|
||||
|
||||
async def _stream_from_decode(
    client: ServiceClient,
    endpoint: str,
    req_data: dict[str, Any],
    request_id: str,
    conversation_id: str,
) -> tuple[str, str | None, dict[str, Any] | None, str, str | None, int | None]:
    """Stream a response from D and collapse it into a single result.

    The request is always sent with ``stream=True``. SSE ``data:`` lines are
    parsed until ``data: [DONE]``; lines that are not ``data:`` lines, or
    whose payload is not valid JSON, are skipped.

    Args:
        client: The decode instance to call.
        endpoint: Path to POST to, relative to the client's base URL.
        req_data: Request body; it is copied, not modified.
        request_id: Sent as the upstream request id and used as the
            fallback response id.
        conversation_id: When non-empty, any ``kv_transfer_params`` seen in
            the stream are stored in ``kv_cache`` under this id.

    Returns:
        ``(collected_text, finish_reason, kv_params, response_id,
        model_name, created)``. ``response_id``, ``model_name`` and
        ``created`` come from the first parsed chunk; ``response_id`` falls
        back to ``request_id`` when that chunk has no ``id``.

    Raises:
        Whatever ``resp.raise_for_status()`` raises for an error status.
    """
    payload = req_data.copy()
    payload["stream"] = True

    # Fragments are joined once at the end instead of growing a string.
    text_parts: list[str] = []
    finish_reason: str | None = None
    response_id: str | None = None
    model_name: str | None = None
    created: int | None = None
    captured_kv: dict[str, Any] | None = None

    async with client.client.stream(
        "POST",
        endpoint,
        json=payload,
        headers=_make_headers(request_id),
    ) as resp:
        resp.raise_for_status()
        async for line in resp.aiter_lines():
            if not line or not line.startswith("data: "):
                continue
            if line == "data: [DONE]":
                break
            try:
                chunk = json.loads(line[6:])
            except json.JSONDecodeError:
                continue

            # Response metadata is taken from the first parsed chunk only.
            if response_id is None:
                response_id = chunk.get("id")
                model_name = chunk.get("model")
                created = chunk.get("created")

            for choice in chunk.get("choices") or []:
                # "text", "delta" and "content" may be missing or an
                # explicit JSON null; both cases contribute no text.
                text_parts.append(choice.get("text") or "")
                delta = choice.get("delta") or {}
                text_parts.append(delta.get("content") or "")
                if choice.get("finish_reason"):
                    finish_reason = choice["finish_reason"]

            kv_params = chunk.get("kv_transfer_params")
            if kv_params:
                # Record which decode host produced these params.
                kv_params["remote_host"] = client.host
                captured_kv = kv_params
                if conversation_id:
                    kv_cache.put(conversation_id, kv_params)

    return (
        "".join(text_parts),
        finish_reason,
        captured_kv,
        response_id or request_id,
        model_name,
        created,
    )
|
||||
|
||||
|
||||
async def _stream_from_decode_sse(
    client: ServiceClient,
    endpoint: str,
    req_data: dict[str, Any],
    request_id: str,
    conversation_id: str,
):
    """Relay D's SSE stream line by line while watching for KV params.

    Every line received from D is yielded back with a trailing newline
    (empty lines become a bare newline). Each ``data:`` payload other than
    ``[DONE]`` is also parsed; when it carries ``kv_transfer_params`` and a
    conversation id is set, the params are tagged with D's host and stored
    in ``kv_cache``. Payloads that are not valid JSON are relayed unchanged.
    """
    decode_body = {**req_data, "stream": True}

    async with client.client.stream(
        "POST",
        endpoint,
        json=decode_body,
        headers=_make_headers(request_id),
    ) as upstream:
        upstream.raise_for_status()
        async for raw in upstream.aiter_lines():
            if not raw:
                yield "\n"
                continue

            carries_json = raw.startswith("data: ") and raw != "data: [DONE]"
            if carries_json:
                try:
                    event = json.loads(raw[6:])
                    d_kv = event.get("kv_transfer_params")
                    if d_kv and conversation_id:
                        d_kv["remote_host"] = client.host
                        kv_cache.put(conversation_id, d_kv)
                except json.JSONDecodeError:
                    # Unparseable payloads are still forwarded below.
                    pass

            yield raw + "\n"
|
||||
|
||||
|
||||
# FastAPI application
|
||||
def _build_service_clients(instances: list[tuple[str, int]]) -> list[ServiceClient]:
    """Create one ``ServiceClient`` per ``(host, port)`` pair, numbered in order."""
    return [
        ServiceClient(
            client=httpx.AsyncClient(
                timeout=None,
                base_url=f"http://{host}:{port}/v1",
            ),
            host=host,
            port=port,
            id=idx,
        )
        for idx, (host, port) in enumerate(instances)
    ]


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the HTTP clients for P and D at startup and close them at shutdown.

    Reads ``global_args.prefiller_instances`` and
    ``global_args.decoder_instances`` and stores on ``app.state`` the two
    client lists plus an endlessly cycling index iterator for each list.
    """
    app.state.prefill_clients = _build_service_clients(
        global_args.prefiller_instances
    )
    app.state.decode_clients = _build_service_clients(global_args.decoder_instances)

    app.state.prefill_iter = itertools.cycle(range(len(app.state.prefill_clients)))
    app.state.decode_iter = itertools.cycle(range(len(app.state.decode_clients)))

    logger.info(
        "Ready: %d prefill, %d decode instances",
        len(app.state.prefill_clients),
        len(app.state.decode_clients),
    )
    try:
        yield
    finally:
        # Close the clients even when an exception is raised into the
        # generator at the yield point, so they are not leaked.
        for sc in app.state.prefill_clients + app.state.decode_clients:
            await sc.client.aclose()
|
||||
|
||||
|
||||
app = FastAPI(title="Disaggregated P/D Proxy (Multi-turn)", lifespan=lifespan)
|
||||
|
||||
|
||||
def _next_client(app_state, role: str) -> ServiceClient:
|
||||
if role == "prefill":
|
||||
return app_state.prefill_clients[next(app_state.prefill_iter)]
|
||||
return app_state.decode_clients[next(app_state.decode_iter)]
|
||||
|
||||
|
||||
# Request handler
|
||||
async def _handle_request(api_path: str, request: Request):
    """Serve one completion request through P and then D.

    Shared by ``/v1/chat/completions`` and ``/v1/completions``. The request's
    ``conversation_id`` is removed from the body and used to look up D's
    cached KV params from the previous turn. The request goes to a prefill
    instance first, then to a decode instance with P's
    ``kv_transfer_params`` attached. The result is either relayed as an SSE
    stream or assembled into a single JSON response, depending on the
    client's ``stream`` flag.
    """
    body_in = await request.json()
    rid = str(uuid.uuid4())
    conv_id: str = body_in.pop("conversation_id", "")
    stream_to_client = body_in.get("stream", False)

    # Step 1: look up cached D blocks from the previous turn. This is only
    # possible when the client identified the conversation.
    d_blocks = None
    if conv_id:
        d_blocks = kv_cache.get(conv_id)
    else:
        logger.warning(
            "[%s] No conversation_id provided — KV cache reuse disabled "
            "for this request. Add a 'conversation_id' field to enable "
            "cross-turn KV sharing.",
            rid,
        )

    if d_blocks:
        # Tell P to read D's blocks (bidirectional transfer).
        d_blocks.update(do_remote_decode=True, do_remote_prefill=False)
        body_in["kv_transfer_params"] = d_blocks
        logger.info(
            "[%s] conv=%s: sending D's cached blocks to P (remote_request_id=%s)",
            rid,
            conv_id,
            d_blocks.get("remote_request_id"),
        )
    else:
        # Nothing cached: P gets no remote block information.
        body_in["kv_transfer_params"] = {
            "do_remote_decode": True,
            "do_remote_prefill": False,
            "remote_engine_id": None,
            "remote_block_ids": None,
            "remote_host": None,
            "remote_port": None,
        }
        logger.info("[%s] conv=%s: cache MISS", rid, conv_id)

    # Step 2: prefill on P (non-streaming, max_tokens=1).
    p_node = _next_client(request.app.state, "prefill")
    started = time.time()
    p_result = await _send_to_prefill(p_node, api_path, body_in, rid)
    elapsed_ms = (time.time() - started) * 1000
    logger.info("[%s] Prefill done in %.0fms", rid, elapsed_ms)

    # Hand P's kv_transfer_params to D. If P returned none, the params sent
    # to P stay in the body.
    p_blocks = p_result.get("kv_transfer_params", {})
    if p_blocks:
        p_blocks["remote_host"] = p_node.host
        body_in["kv_transfer_params"] = p_blocks

    # Step 3: decode on D, capturing its kv_transfer_params for the next turn.
    d_node = _next_client(request.app.state, "decode")

    if stream_to_client:
        sse_lines = _stream_from_decode_sse(d_node, api_path, body_in, rid, conv_id)
        return StreamingResponse(sse_lines, media_type="text/event-stream")

    text, finish_reason, _, resp_id, model, created = await _stream_from_decode(
        d_node,
        api_path,
        body_in,
        rid,
        conv_id,
    )

    # Build the OpenAI-compatible response. Chat and text completions differ
    # only in the object type and the shape of the single choice.
    if "messages" in body_in:
        object_type = "chat.completion"
        choice = {
            "index": 0,
            "message": {"role": "assistant", "content": text},
            "finish_reason": finish_reason,
        }
    else:
        object_type = "text_completion"
        choice = {
            "index": 0,
            "text": text,
            "logprobs": None,
            "finish_reason": finish_reason,
        }

    return JSONResponse(
        content={
            "id": resp_id,
            "object": object_type,
            "created": created or int(time.time()),
            "model": model or body_in.get("model", ""),
            "choices": [choice],
            "usage": None,
        }
    )
|
||||
|
||||
|
||||
# Routes
|
||||
@app.post("/v1/chat/completions")
async def chat_completions(request: Request):
    """Route chat completion requests to the shared handler."""
    response = await _handle_request("/chat/completions", request)
    return response
|
||||
|
||||
|
||||
@app.post("/v1/completions")
async def completions(request: Request):
    """Route text completion requests to the shared handler."""
    response = await _handle_request("/completions", request)
    return response
|
||||
|
||||
|
||||
@app.get("/health")
async def health():
    """Report liveness and evict stale cache entries as a side effect.

    The returned conversation count is read after eviction, so it excludes
    the entries that were just removed.
    """
    num_evicted = kv_cache.evict_stale()
    remaining = kv_cache.size
    return {
        "status": "ok",
        "cached_conversations": remaining,
        "evicted_stale": num_evicted,
    }
|
||||
|
||||
|
||||
# CLI
|
||||
def parse_args() -> argparse.Namespace:
    """Parse the proxy's command-line options.

    Each role (prefiller, decoder) accepts one or more hosts and the same
    number of ports, in singular or plural spelling. The parser exits with
    an error when the counts differ. The returned namespace additionally
    carries ``prefiller_instances`` and ``decoder_instances`` as lists of
    ``(host, port)`` pairs.
    """
    parser = argparse.ArgumentParser(
        description="Disaggregated P/D proxy with bidirectional KV transfer",
    )
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--port", type=int, default=8000)

    # Both roles take the same pair of options; only the default port differs.
    for role, default_port in (("prefiller", 8100), ("decoder", 8200)):
        parser.add_argument(
            f"--{role}-host",
            f"--{role}-hosts",
            dest=f"{role}_hosts",
            nargs="+",
            default=["localhost"],
        )
        parser.add_argument(
            f"--{role}-port",
            f"--{role}-ports",
            dest=f"{role}_ports",
            type=int,
            nargs="+",
            default=[default_port],
        )

    args = parser.parse_args()

    # Validate both roles before any pairing is attached.
    for role in ("prefiller", "decoder"):
        hosts = getattr(args, f"{role}_hosts")
        ports = getattr(args, f"{role}_ports")
        if len(hosts) != len(ports):
            parser.error(f"Number of {role} hosts must match ports")

    for role in ("prefiller", "decoder"):
        hosts = getattr(args, f"{role}_hosts")
        ports = getattr(args, f"{role}_ports")
        setattr(args, f"{role}_instances", list(zip(hosts, ports)))
    return args
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # A plain assignment here already creates the module-level name that
    # lifespan() reads at startup; no ``global`` statement is needed at
    # module scope.
    global_args = parse_args()

    # Imported after argument parsing, inside the entry-point guard, so
    # importing this module does not require uvicorn.
    import uvicorn

    uvicorn.run(app, host=global_args.host, port=global_args.port)
|
||||
@@ -0,0 +1,915 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Unit tests for bi-directional KV cache transfer between P and D nodes.
|
||||
|
||||
Tests cover the new behaviors added by the bi-directional KV transfer PR:
|
||||
1. P-node scheduler lifecycle: P pulls KV from D using remote_block_ids,
|
||||
eliminating redundant prefill computation in multi-turn conversations.
|
||||
2. P-node metadata: NixlConnectorMetadata correctly populates recv metadata
|
||||
when P pulls KV from D (do_remote_decode=True + remote_block_ids).
|
||||
3. P-node worker: start_load_kv processes reqs_to_recv for KV pull from D.
|
||||
4. D-node request_finished: returns kv_transfer_params with remote_block_ids
|
||||
and remote_num_tokens so P can pull KV in future turns.
|
||||
5. Edge cases:
|
||||
- No double read after reschedule (_remote_blocks_processed flag)
|
||||
- remote_num_tokens bounded by block capacity (num_computed_tokens)
|
||||
- kv_recompute_threshold skips small transfers
|
||||
- P-node holds blocks for D after finishing
|
||||
- Cache MISS first turn falls back to local prefill
|
||||
- Partial remote coverage: P pulls partial, computes the rest
|
||||
- _remote_blocks_processed flag persists across reschedules
|
||||
|
||||
P-node flags: do_remote_prefill=False (prefill locally),
|
||||
do_remote_decode=True (don't decode locally, send KV to D).
|
||||
P pulls KV from D when remote_block_ids is not None and
|
||||
external tokens > 0.
|
||||
"""
|
||||
|
||||
import copy
|
||||
import time
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorRole
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.nixl.connector import (
|
||||
NixlConnector,
|
||||
NixlConnectorMetadata,
|
||||
)
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.v1.outputs import (
|
||||
EMPTY_MODEL_RUNNER_OUTPUT,
|
||||
KVConnectorOutput,
|
||||
)
|
||||
from vllm.v1.request import RequestStatus
|
||||
|
||||
from .test_nixl_connector import FakeNixlConnectorWorker, FakeNixlWrapper
|
||||
from .utils import (
|
||||
assert_scheduler_empty,
|
||||
create_model_runner_output,
|
||||
create_request,
|
||||
create_scheduler,
|
||||
create_vllm_config,
|
||||
make_kv_cache_config,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.cpu_test
|
||||
|
||||
# Connector extra-config shared by the bi-directional KV transfer tests
# in this file.
BIDIR_KV_EXTRA_CONFIG = dict(bidirectional_kv_xfer=True, kv_recompute_threshold=0)
|
||||
|
||||
|
||||
# Helpers
|
||||
|
||||
|
||||
def _make_p_node_turn2_request(
    request_id, block_size, num_tokens, num_remote_blocks=3, remote_num_tokens=None
):
    """Build a P-node Turn 2 request that carries remote_block_ids from D.

    When ``remote_num_tokens`` is omitted it defaults to the full capacity
    of the remote blocks (``num_remote_blocks * block_size``).
    """
    if remote_num_tokens is None:
        remote_num_tokens = num_remote_blocks * block_size

    req = create_request(
        request_id=request_id,
        block_size=block_size,
        num_tokens=num_tokens,
        do_remote_decode=True,
    )
    # Describe the blocks held by the (fake) decode node.
    req.kv_transfer_params.update(
        {
            "remote_block_ids": [list(range(num_remote_blocks))],
            "remote_num_tokens": remote_num_tokens,
            "remote_engine_id": "decode-engine",
            "remote_request_id": f"decode-{request_id}",
            "remote_host": "decode-host",
            "remote_port": 5678,
        }
    )
    return req
|
||||
|
||||
|
||||
def _make_connector_with_fake_worker(
    hand_shake_latency=0, cycles_before_done=0, do_handshake=True
):
    """Build a worker-role NixlConnector backed by a FakeNixlConnectorWorker.

    Returns ``(connector, worker)``. Unless ``do_handshake`` is false, the
    worker also performs a handshake and registers the resulting agents
    under the fake remote engine id.
    """
    config = create_vllm_config()
    cache_config = make_kv_cache_config(block_size=16, num_blocks=2)
    connector = NixlConnector(config, KVConnectorRole.WORKER, cache_config)
    connector.connector_worker = FakeNixlConnectorWorker(
        config,
        connector.engine_id,
        hand_shake_latency=hand_shake_latency,
        kv_cache_config=cache_config,
    )

    worker = connector.connector_worker
    assert isinstance(worker.nixl_wrapper, FakeNixlWrapper)
    worker.nixl_wrapper.set_cycles_before_xfer_done(cycles_before_done)
    worker.kv_cache_layout = "HND"

    if not do_handshake:
        return connector, worker

    remote_engine = FakeNixlConnectorWorker.REMOTE_ENGINE_ID
    worker._remote_agents[remote_engine] = worker._nixl_handshake(
        host="localhost",
        port=1234,
        remote_tp_size=1,
        expected_engine_id=remote_engine,
    )
    return connector, worker
|
||||
|
||||
|
||||
def _make_p_node_recv_metadata(request_id, local_blocks, remote_blocks):
    """Return NixlConnectorMetadata with one recv request for a P-node pull from D.

    ``local_blocks`` and ``remote_blocks`` are each wrapped in a one-element
    tuple before being recorded.
    """
    transfer_params = {
        "do_remote_prefill": False,
        "do_remote_decode": True,
        "remote_block_ids": (remote_blocks,),
        "remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
        "remote_request_id": f"decode-{request_id}",
        "remote_host": "localhost",
        "remote_port": 1234,
        "remote_tp_size": 1,
    }
    metadata = NixlConnectorMetadata()
    metadata.add_new_req_to_recv(
        request_id=request_id,
        local_block_ids=(local_blocks,),
        kv_transfer_params=transfer_params,
    )
    return metadata
|
||||
|
||||
|
||||
def _do_load_kv(connector, metadata):
    """Bind *metadata* to the connector, then call start_load_kv with an empty context."""
    connector.bind_connector_metadata(metadata)
    connector.start_load_kv(
        ForwardContext(no_compile_layers={}, attn_metadata={}, slot_mapping={})
    )
|
||||
|
||||
|
||||
# 1. P-node scheduler lifecycle tests
|
||||
|
||||
|
||||
def test_multiturn_lifecycle():
    """Run a complete two-turn conversation through the P-node scheduler.

    Turn 1: P prefills locally (do_remote_prefill=False) and sends KV to D
    (do_remote_decode=True). It finishes LENGTH_CAPPED and reports
    remote_block_ids.
    Turn 2: the request arrives with remote_block_ids from D, so P pulls KV
    from D (remote_block_ids is not None and external tokens > 0), computes
    only the new tokens, and finishes LENGTH_CAPPED.
    """
    config = create_vllm_config(
        kv_connector_extra_config=BIDIR_KV_EXTRA_CONFIG,
    )
    scheduler = create_scheduler(config)
    block_size = config.cache_config.block_size

    # ---- Turn 1: no remote blocks yet.
    turn1 = create_request(
        request_id=100,
        block_size=block_size,
        num_tokens=int(block_size * 2.5),
        do_remote_decode=True,
    )
    scheduler.add_request(turn1)
    turn1_id = turn1.request_id

    sched_out = scheduler.schedule()
    runner_out = create_model_runner_output(reqs=[turn1])
    engine_outputs = scheduler.update_from_output(sched_out, runner_out)
    assert turn1.status == RequestStatus.FINISHED_LENGTH_CAPPED
    kv = engine_outputs[0].outputs[0].kv_transfer_params
    assert kv and sum(len(g) for g in kv["remote_block_ids"]) > 0

    # One step with empty model-runner output.
    scheduler.update_from_output(scheduler.schedule(), EMPTY_MODEL_RUNNER_OUTPUT)

    # ---- Turn 2: request carries D's blocks.
    turn2 = _make_p_node_turn2_request(200, block_size, int(block_size * 2.5))
    scheduler.add_request(turn2)
    turn2_id = turn2.request_id

    sched_out = scheduler.schedule()
    assert turn2.status == RequestStatus.WAITING_FOR_REMOTE_KVS
    scheduler.update_from_output(sched_out, EMPTY_MODEL_RUNNER_OUTPUT)

    # Report the receive for turn 2 as finished.
    sched_out = scheduler.schedule()
    runner_out = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
    runner_out.kv_connector_output = KVConnectorOutput(finished_recving={turn2_id})
    scheduler.update_from_output(sched_out, runner_out)

    # Run the model step for turn 2.
    sched_out = scheduler.schedule()
    runner_out = create_model_runner_output(reqs=[turn2])
    scheduler.update_from_output(sched_out, runner_out)
    assert turn2.status == RequestStatus.FINISHED_LENGTH_CAPPED

    scheduler.update_from_output(scheduler.schedule(), EMPTY_MODEL_RUNNER_OUTPUT)

    # Report both sends as finished; nothing should remain afterwards.
    sched_out = scheduler.schedule()
    runner_out = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
    runner_out.kv_connector_output = KVConnectorOutput(
        finished_sending={turn1_id, turn2_id}
    )
    scheduler.update_from_output(sched_out, runner_out)
    assert_scheduler_empty(scheduler)
|
||||
|
||||
|
||||
def test_first_turn_no_remote_blocks():
    """First turn, before D has supplied any remote_block_ids.

    P performs a standard local prefill and still returns
    kv_transfer_params for use in later turns.
    """
    config = create_vllm_config(
        kv_connector_extra_config=BIDIR_KV_EXTRA_CONFIG,
    )
    scheduler = create_scheduler(config)
    block_size = config.cache_config.block_size

    request = create_request(
        request_id=3,
        block_size=block_size,
        num_tokens=int(block_size * 2.5),
        do_remote_decode=True,
    )
    scheduler.add_request(request)
    request_id = request.request_id

    # The request must not be parked waiting for a remote transfer.
    sched_out = scheduler.schedule()
    assert request.status != RequestStatus.WAITING_FOR_REMOTE_KVS

    runner_out = create_model_runner_output(reqs=[request])
    engine_outputs = scheduler.update_from_output(sched_out, runner_out)
    assert request.status == RequestStatus.FINISHED_LENGTH_CAPPED
    assert engine_outputs[0].outputs[0].kv_transfer_params is not None

    scheduler.update_from_output(scheduler.schedule(), EMPTY_MODEL_RUNNER_OUTPUT)

    # Report the send as finished; nothing should remain afterwards.
    sched_out = scheduler.schedule()
    runner_out = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
    runner_out.kv_connector_output = KVConnectorOutput(finished_sending={request_id})
    scheduler.update_from_output(sched_out, runner_out)
    assert_scheduler_empty(scheduler)
|
||||
|
||||
|
||||
def test_abort_p_side_during_send():
    """P side with do_remote_decode=True: the request stays tracked until finished_sending."""
    config = create_vllm_config(
        kv_connector_extra_config=BIDIR_KV_EXTRA_CONFIG,
    )
    scheduler = create_scheduler(config)
    block_size = config.cache_config.block_size

    request = create_request(
        request_id=42,
        block_size=block_size,
        num_tokens=int(block_size * 2.5),
        do_remote_decode=True,
    )
    scheduler.add_request(request)
    request_id = request.request_id

    # After the model step the request is still known to the scheduler.
    sched_out = scheduler.schedule()
    runner_out = create_model_runner_output(reqs=[request])
    scheduler.update_from_output(sched_out, runner_out)
    assert request_id in scheduler.requests

    # It is still tracked after an empty step.
    scheduler.update_from_output(scheduler.schedule(), EMPTY_MODEL_RUNNER_OUTPUT)
    assert request_id in scheduler.requests

    # Only once the send is reported finished does the scheduler empty out.
    sched_out = scheduler.schedule()
    runner_out = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
    runner_out.kv_connector_output = KVConnectorOutput(finished_sending={request_id})
    scheduler.update_from_output(sched_out, runner_out)
    assert_scheduler_empty(scheduler)
|
||||
|
||||
|
||||
def test_abort_p_side_non_length_capped():
    """P-side abort of a request that is not LENGTH_CAPPED frees it immediately."""
    config = create_vllm_config(
        kv_connector_extra_config=BIDIR_KV_EXTRA_CONFIG,
    )
    scheduler = create_scheduler(config)
    block_size = config.cache_config.block_size

    request = create_request(
        request_id=44,
        block_size=block_size,
        num_tokens=int(block_size * 2.5),
        do_remote_decode=True,
    )
    # Raise the token limit to 100 on both the sampling params and the request.
    request.sampling_params.max_tokens = 100
    request.max_tokens = 100
    scheduler.add_request(request)
    request_id = request.request_id

    sched_out = scheduler.schedule()
    runner_out = create_model_runner_output(reqs=[request])
    scheduler.update_from_output(sched_out, runner_out)

    # Abort: the connector records the request as not processed and the
    # scheduler forgets it right away.
    scheduler.finish_requests([request_id], RequestStatus.FINISHED_ABORTED)
    connector_scheduler = scheduler.connector.connector_scheduler
    assert request_id in connector_scheduler._reqs_not_processed
    assert request_id not in scheduler.requests

    scheduler.update_from_output(scheduler.schedule(), EMPTY_MODEL_RUNNER_OUTPUT)
    assert_scheduler_empty(scheduler)
|
||||
|
||||
|
||||
def test_remote_blocks_exceed_prompt_tokens():
    """D offers more remote tokens than P's prompt contains.

    P caps the external tokens at the prompt length.
    """
    config = create_vllm_config(
        kv_connector_extra_config=BIDIR_KV_EXTRA_CONFIG,
    )
    scheduler = create_scheduler(config)
    block_size = config.cache_config.block_size
    prompt_len = int(block_size * 2.5)

    # Five full remote blocks against a 2.5-block prompt.
    request = _make_p_node_turn2_request(
        300,
        block_size,
        prompt_len,
        num_remote_blocks=5,
        remote_num_tokens=5 * block_size,
    )
    scheduler.add_request(request)
    request_id = request.request_id

    sched_out = scheduler.schedule()
    assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
    assert request.num_computed_tokens == prompt_len
    scheduler.update_from_output(sched_out, EMPTY_MODEL_RUNNER_OUTPUT)

    # Report the receive as finished.
    sched_out = scheduler.schedule()
    runner_out = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
    runner_out.kv_connector_output = KVConnectorOutput(finished_recving={request_id})
    scheduler.update_from_output(sched_out, runner_out)

    # Run the model step.
    sched_out = scheduler.schedule()
    runner_out = create_model_runner_output(reqs=[request])
    scheduler.update_from_output(sched_out, runner_out)
    assert request.status == RequestStatus.FINISHED_LENGTH_CAPPED

    scheduler.update_from_output(scheduler.schedule(), EMPTY_MODEL_RUNNER_OUTPUT)

    # Report the send as finished; nothing should remain afterwards.
    sched_out = scheduler.schedule()
    runner_out = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
    runner_out.kv_connector_output = KVConnectorOutput(finished_sending={request_id})
    scheduler.update_from_output(sched_out, runner_out)
    assert_scheduler_empty(scheduler)
|
||||
|
||||
|
||||
def test_p_node_pulls_partial_last_block_from_d():
    """P-node request where D's last remote block is only partly filled.

    D advertises 3 remote blocks but only 2.5 blocks' worth of
    ``remote_num_tokens``; the prompt is 3.5 blocks long. The request must
    wait for remote KVs, run once the receive is reported, finish as
    length-capped, and leave the scheduler empty after the send completes.
    """
    config = create_vllm_config(
        kv_connector_extra_config=BIDIR_KV_EXTRA_CONFIG,
    )
    scheduler = create_scheduler(config)
    block_size = config.cache_config.block_size
    num_remote_blocks = 3
    remote_num_tokens = int(block_size * 2.5)
    # Premise of the test: the final remote block is not full.
    assert remote_num_tokens < num_remote_blocks * block_size
    prompt_len = int(block_size * 3.5)
    request = _make_p_node_turn2_request(
        400,
        block_size,
        prompt_len,
        num_remote_blocks=num_remote_blocks,
        remote_num_tokens=remote_num_tokens,
    )
    scheduler.add_request(request)
    request_id = request.request_id

    # Initial schedule: the request parks until the remote KV arrives.
    sched_out = scheduler.schedule()
    assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
    scheduler.update_from_output(sched_out, EMPTY_MODEL_RUNNER_OUTPUT)

    # Report the receive as finished.
    sched_out = scheduler.schedule()
    recv_done = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
    recv_done.kv_connector_output = KVConnectorOutput(finished_recving={request_id})
    scheduler.update_from_output(sched_out, recv_done)

    # The request is now running; execute it to completion.
    sched_out = scheduler.schedule()
    assert len(scheduler.running) == 1
    runner_out = create_model_runner_output(reqs=[request])
    scheduler.update_from_output(sched_out, runner_out)
    assert request.status == RequestStatus.FINISHED_LENGTH_CAPPED

    # One idle step, then report the send as finished.
    sched_out = scheduler.schedule()
    scheduler.update_from_output(sched_out, EMPTY_MODEL_RUNNER_OUTPUT)
    sched_out = scheduler.schedule()
    send_done = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
    send_done.kv_connector_output = KVConnectorOutput(finished_sending={request_id})
    scheduler.update_from_output(sched_out, send_done)
    assert_scheduler_empty(scheduler)
|
||||
|
||||
|
||||
# 2. P-node metadata tests
|
||||
|
||||
|
||||
def test_add_new_req_to_recv_populates_remote_meta():
    """``add_new_req_to_recv`` copies the remote fields into the request meta.

    Registers one request with decode-side ``kv_transfer_params`` and checks
    that the stored entry exposes the same remote block ids, engine id,
    request id, host and port, plus the local block ids that were passed in.
    """
    connector_meta = NixlConnectorMetadata()
    kv_params = {
        "remote_block_ids": [[0, 1, 2]],
        "remote_engine_id": "decode-engine",
        "remote_request_id": "decode-req-123",
        "remote_host": "decode-host",
        "remote_port": 5678,
    }
    local_block_ids = ([10, 11, 12],)
    connector_meta.add_new_req_to_recv(
        request_id="test-req",
        local_block_ids=local_block_ids,
        kv_transfer_params=kv_params,
    )

    assert "test-req" in connector_meta.reqs_to_recv
    req_meta = connector_meta.reqs_to_recv["test-req"]
    remote = req_meta.remote
    assert remote is not None
    assert remote.block_ids == kv_params["remote_block_ids"]
    assert remote.engine_id == "decode-engine"
    assert remote.request_id == "decode-req-123"
    assert remote.host == "decode-host"
    assert remote.port == 5678
    assert req_meta.local_block_ids == local_block_ids
|
||||
|
||||
|
||||
def test_build_connector_meta_recv_entries():
    """Scheduling a P-node turn-2 request emits a ``reqs_to_recv`` entry.

    After one ``schedule()`` call the request must be waiting for remote KVs
    and the scheduler output's connector metadata must list it in
    ``reqs_to_recv`` with remote engine id ``"decode-engine"``.
    """
    config = create_vllm_config(
        kv_connector_extra_config=BIDIR_KV_EXTRA_CONFIG,
    )
    scheduler = create_scheduler(config)
    block_size = config.cache_config.block_size
    request = _make_p_node_turn2_request(1, block_size, int(block_size * 2.5))
    scheduler.add_request(request)
    request_id = request.request_id

    sched_out = scheduler.schedule()
    assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS

    connector_meta = sched_out.kv_connector_metadata
    assert isinstance(connector_meta, NixlConnectorMetadata)
    assert request_id in connector_meta.reqs_to_recv
    req_meta = connector_meta.reqs_to_recv[request_id]
    assert req_meta.remote is not None
    assert req_meta.remote.engine_id == "decode-engine"
|
||||
|
||||
|
||||
def test_build_connector_meta_clears_reqs_need_recv():
    """The connector scheduler's ``_reqs_need_recv`` is empty after scheduling.

    Adds one P-node turn-2 request, runs a single ``schedule()`` step, and
    checks that no pending-receive entries remain on the connector scheduler.
    """
    config = create_vllm_config(
        kv_connector_extra_config=BIDIR_KV_EXTRA_CONFIG,
    )
    scheduler = create_scheduler(config)
    block_size = config.cache_config.block_size
    request = _make_p_node_turn2_request(2, block_size, int(block_size * 2.5))
    scheduler.add_request(request)
    connector_scheduler = scheduler.connector.connector_scheduler

    # The scheduler output itself is not inspected here.
    scheduler.schedule()
    assert len(connector_scheduler._reqs_need_recv) == 0
|
||||
|
||||
|
||||
def test_build_connector_meta_multiple_requests():
    """Three P-node turn-2 requests all appear in ``reqs_to_recv``.

    The requests (ids 10, 11, 12) are added before a single ``schedule()``
    call; the resulting connector metadata must contain exactly three
    receive entries, one per request.
    """
    config = create_vllm_config(
        kv_connector_extra_config=BIDIR_KV_EXTRA_CONFIG,
    )
    scheduler = create_scheduler(config)
    block_size = config.cache_config.block_size
    prompt_len = int(block_size * 2.5)

    requests = []
    for offset in range(3):
        requests.append(
            _make_p_node_turn2_request(10 + offset, block_size, prompt_len)
        )
    for request in requests:
        scheduler.add_request(request)

    sched_out = scheduler.schedule()
    connector_meta = sched_out.kv_connector_metadata
    assert isinstance(connector_meta, NixlConnectorMetadata)
    assert len(connector_meta.reqs_to_recv) == 3
    for request in requests:
        assert request.request_id in connector_meta.reqs_to_recv
|
||||
|
||||
|
||||
# 3. P-node worker tests (FakeNixlWrapper)
|
||||
|
||||
|
||||
@patch(
    "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.NixlWrapper",
    FakeNixlWrapper,
)
def test_p_node_pull_kv_from_d(dist_init):
    """Worker-side pull: a ``reqs_to_recv`` entry is loaded and completes.

    Runs with ``NixlWrapper`` patched to ``FakeNixlWrapper``. After loading
    the metadata for ``"req-p1"`` the worker must track it in
    ``_recving_metadata`` and ``get_finished`` must report it as received.
    """
    connector, worker = _make_connector_with_fake_worker()
    recv_meta = _make_p_node_recv_metadata("req-p1", [10, 11, 12], [20, 21, 22])
    _do_load_kv(connector, recv_meta)
    assert "req-p1" in worker._recving_metadata

    finished = connector.get_finished(finished_req_ids=set())
    _, done_recving = finished
    assert "req-p1" in done_recving
|
||||
|
||||
|
||||
@patch(
    "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.NixlWrapper",
    FakeNixlWrapper,
)
def test_p_node_pull_then_send_kv(dist_init):
    """Full P-node bi-directional: pull KV from D → prefill →
    send KV back to D via notification.

    Runs with ``NixlWrapper`` patched to ``FakeNixlWrapper``. The test first
    checks that ``"req-p2"`` is reported as received, then registers it as a
    pending send and stubs ``get_new_notifs`` to deliver a notification of
    the form ``"req-p2:<world_size>"``; ``get_finished`` must then report the
    request as sent.
    """
    connector, worker = _make_connector_with_fake_worker()
    meta = _make_p_node_recv_metadata("req-p2", [10, 11], [20, 21])
    _do_load_kv(connector, meta)
    _, done_recving = connector.get_finished(finished_req_ids=set())
    assert "req-p2" in done_recving

    # Mark the request as awaiting a read by the remote side.
    worker._reqs_to_send["req-p2"] = time.perf_counter() + 60
    worker._reqs_to_process.add("req-p2")
    notif = f"req-p2:{worker.world_size}".encode()

    orig = worker.nixl_wrapper.get_new_notifs
    worker.nixl_wrapper.get_new_notifs = lambda: {"agent": [notif]}
    try:
        done_sending, _ = connector.get_finished(finished_req_ids=set())
        assert "req-p2" in done_sending
    finally:
        # Restore the wrapper even if the assertion above fails, so the stub
        # never outlives this test body.
        worker.nixl_wrapper.get_new_notifs = orig
|
||||
|
||||
|
||||
@patch(
    "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.NixlWrapper",
    FakeNixlWrapper,
)
def test_p_node_deferred_pull_on_no_handshake(dist_init):
    """The pull for ``"req-p3"`` still completes without a prior handshake.

    The fake worker is created with ``do_handshake=False``. After the first
    load the request must be tracked in ``_recving_metadata``; the test then
    polls with empty connector metadata for up to 3 seconds until
    ``get_finished`` reports the request as received.
    """
    connector, worker = _make_connector_with_fake_worker(
        hand_shake_latency=0, do_handshake=False
    )
    recv_meta = _make_p_node_recv_metadata("req-p3", [10, 11], [20, 21])
    _do_load_kv(connector, recv_meta)
    assert "req-p3" in worker._recving_metadata

    timeout = 3.0
    start = time.perf_counter()
    while time.perf_counter() - start < timeout:
        # Each poll binds fresh, empty metadata and re-triggers loading.
        connector.bind_connector_metadata(NixlConnectorMetadata())
        forward_ctx = ForwardContext(
            no_compile_layers={}, attn_metadata={}, slot_mapping={}
        )
        connector.start_load_kv(forward_ctx)
        _, done_recving = connector.get_finished(finished_req_ids=set())
        if "req-p3" in done_recving:
            break
        time.sleep(0.2)
    else:
        # Loop ran out of time without seeing the request complete.
        raise AssertionError("Transfer did not complete after async handshake")
|
||||
|
||||
|
||||
# 4. D-node request_finished returns kv_transfer_params (new behavior)
|
||||
|
||||
|
||||
def test_d_node_request_finished_returns_kv_params():
    """A finished D-node request returns ``kv_transfer_params`` for P.

    The request is created with ``do_remote_prefill=True``, its receive is
    reported as finished, and it is then run with ``use_eos=True``. The
    resulting output must carry ``do_remote_decode=True``,
    ``do_remote_prefill=False``, ``remote_block_ids`` and a positive
    ``remote_num_tokens``.
    """
    config = create_vllm_config(
        kv_connector_extra_config=BIDIR_KV_EXTRA_CONFIG,
    )
    scheduler = create_scheduler(config)
    block_size = config.cache_config.block_size
    request = create_request(
        request_id=1,
        block_size=block_size,
        num_tokens=int(block_size * 2.5),
        do_remote_prefill=True,
    )
    scheduler.add_request(request)
    request_id = request.request_id

    # Schedule once, then report the remote prefill as received.
    sched_out = scheduler.schedule()
    recv_done = create_model_runner_output(reqs=[], finished_recving={request_id})
    scheduler.update_from_output(sched_out, recv_done)

    # Run the request to an EOS stop and capture the engine outputs.
    sched_out = scheduler.schedule()
    eos_out = create_model_runner_output(reqs=[request], use_eos=True)
    engine_outputs = scheduler.update_from_output(sched_out, eos_out)
    assert request.status == RequestStatus.FINISHED_STOPPED

    kv = engine_outputs[0].outputs[0].kv_transfer_params
    assert kv is not None
    assert kv["do_remote_decode"] is True
    assert kv["do_remote_prefill"] is False
    assert "remote_block_ids" in kv
    assert "remote_num_tokens" in kv
    assert kv["remote_num_tokens"] > 0
|
||||
|
||||
|
||||
def test_d_node_request_finished_delays_block_free():
    """A finished D-node request stays tracked instead of being freed.

    After the request stops on EOS it must still be present in
    ``scheduler.requests`` and be registered in the connector scheduler's
    ``_reqs_need_send``.
    """
    config = create_vllm_config(
        kv_connector_extra_config=BIDIR_KV_EXTRA_CONFIG,
    )
    scheduler = create_scheduler(config)
    block_size = config.cache_config.block_size
    request = create_request(
        request_id=2,
        block_size=block_size,
        num_tokens=int(block_size * 2.5),
        do_remote_prefill=True,
    )
    scheduler.add_request(request)
    request_id = request.request_id

    # Schedule once, then report the remote prefill as received.
    sched_out = scheduler.schedule()
    recv_done = create_model_runner_output(reqs=[], finished_recving={request_id})
    scheduler.update_from_output(sched_out, recv_done)

    # Run the request to an EOS stop.
    sched_out = scheduler.schedule()
    eos_out = create_model_runner_output(reqs=[request], use_eos=True)
    scheduler.update_from_output(sched_out, eos_out)

    assert request_id in scheduler.requests
    connector_scheduler = scheduler.connector.connector_scheduler
    assert request_id in connector_scheduler._reqs_need_send
|
||||
|
||||
|
||||
def test_d_node_request_finished_remote_num_tokens():
    """Finished D-node params carry non-empty block ids and token count.

    Checks that ``remote_num_tokens`` is positive and that
    ``remote_block_ids`` contains at least one block across all groups.
    """
    config = create_vllm_config(
        kv_connector_extra_config=BIDIR_KV_EXTRA_CONFIG,
    )
    scheduler = create_scheduler(config)
    block_size = config.cache_config.block_size
    request = create_request(
        request_id=3,
        block_size=block_size,
        num_tokens=int(block_size * 2.5),
        do_remote_prefill=True,
    )
    scheduler.add_request(request)
    request_id = request.request_id

    # Schedule once, then report the remote prefill as received.
    sched_out = scheduler.schedule()
    recv_done = create_model_runner_output(reqs=[], finished_recving={request_id})
    scheduler.update_from_output(sched_out, recv_done)

    # Run the request to an EOS stop and capture the engine outputs.
    sched_out = scheduler.schedule()
    eos_out = create_model_runner_output(reqs=[request], use_eos=True)
    engine_outputs = scheduler.update_from_output(sched_out, eos_out)

    kv = engine_outputs[0].outputs[0].kv_transfer_params
    assert kv["remote_num_tokens"] > 0
    block_count = sum(len(group) for group in kv["remote_block_ids"])
    assert block_count > 0
|
||||
|
||||
|
||||
def test_d_node_partial_last_block_remote_num_tokens():
    """D-node reports fewer tokens than its blocks could hold.

    For a 2.5-block request the returned params must list exactly 3 blocks
    while ``remote_num_tokens`` stays positive and strictly below
    ``3 * block_size``.
    """
    config = create_vllm_config(
        kv_connector_extra_config=BIDIR_KV_EXTRA_CONFIG,
    )
    scheduler = create_scheduler(config)
    block_size = config.cache_config.block_size
    request = create_request(
        request_id=5,
        block_size=block_size,
        num_tokens=int(block_size * 2.5),
        do_remote_prefill=True,
    )
    scheduler.add_request(request)
    request_id = request.request_id

    # Schedule once, then report the remote prefill as received.
    sched_out = scheduler.schedule()
    recv_done = create_model_runner_output(reqs=[], finished_recving={request_id})
    scheduler.update_from_output(sched_out, recv_done)

    # Run the request to an EOS stop and capture the engine outputs.
    sched_out = scheduler.schedule()
    eos_out = create_model_runner_output(reqs=[request], use_eos=True)
    engine_outputs = scheduler.update_from_output(sched_out, eos_out)

    kv = engine_outputs[0].outputs[0].kv_transfer_params
    total_blocks = sum(len(group) for group in kv["remote_block_ids"])
    assert total_blocks == 3
    assert kv["remote_num_tokens"] < total_blocks * block_size
    assert kv["remote_num_tokens"] > 0
|
||||
|
||||
|
||||
# 5. Edge case tests
|
||||
|
||||
|
||||
def test_no_double_read_blocks_after_reschedule():
    """Edge case 1: a bidirectional request is queued for receive only once.

    The request appears in ``reqs_to_recv`` on the first schedule. Once the
    receive is reported as finished and the request is scheduled again, the
    new connector metadata must not list it in ``reqs_to_recv`` a second
    time. The remainder of the test drains the scheduler.
    """
    config = create_vllm_config(
        kv_connector_extra_config=BIDIR_KV_EXTRA_CONFIG,
    )
    scheduler = create_scheduler(config)
    block_size = config.cache_config.block_size
    request = _make_p_node_turn2_request(500, block_size, int(block_size * 2.5))
    scheduler.add_request(request)
    request_id = request.request_id
    connector_scheduler = scheduler.connector.connector_scheduler

    # First schedule: the request waits for remote KVs and is listed for
    # receive in the emitted metadata.
    sched_out = scheduler.schedule()
    assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
    first_meta = sched_out.kv_connector_metadata
    assert isinstance(first_meta, NixlConnectorMetadata)
    assert request_id in first_meta.reqs_to_recv
    # The pending-receive map must already be drained at this point.
    assert len(connector_scheduler._reqs_need_recv) == 0

    # Simulate completion of the receive.
    scheduler.update_from_output(sched_out, EMPTY_MODEL_RUNNER_OUTPUT)
    sched_out = scheduler.schedule()
    recv_done = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
    recv_done.kv_connector_output = KVConnectorOutput(finished_recving={request_id})
    scheduler.update_from_output(sched_out, recv_done)

    # Schedule again after the receive: no second receive entry may appear.
    sched_out = scheduler.schedule()
    second_meta = sched_out.kv_connector_metadata
    assert isinstance(second_meta, NixlConnectorMetadata)
    assert request_id not in second_meta.reqs_to_recv

    # Clean up: run to completion, then report the send as finished.
    runner_out = create_model_runner_output(reqs=[request])
    scheduler.update_from_output(sched_out, runner_out)
    assert request.status == RequestStatus.FINISHED_LENGTH_CAPPED
    sched_out = scheduler.schedule()
    scheduler.update_from_output(sched_out, EMPTY_MODEL_RUNNER_OUTPUT)
    sched_out = scheduler.schedule()
    send_done = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
    send_done.kv_connector_output = KVConnectorOutput(finished_sending={request_id})
    scheduler.update_from_output(sched_out, send_done)
    assert_scheduler_empty(scheduler)
|
||||
|
||||
|
||||
def test_remote_num_tokens_bounded_by_blocks():
    """Edge case 2: ``remote_num_tokens`` never exceeds block capacity.

    For a finished D-node request, the returned ``remote_num_tokens`` must be
    positive and at most ``len(remote_block_ids) * block_size`` summed over
    all block groups.
    """
    config = create_vllm_config(
        kv_connector_extra_config=BIDIR_KV_EXTRA_CONFIG,
    )
    scheduler = create_scheduler(config)
    block_size = config.cache_config.block_size
    request = create_request(
        request_id=501,
        block_size=block_size,
        num_tokens=int(block_size * 2.5),
        do_remote_prefill=True,
    )
    scheduler.add_request(request)
    request_id = request.request_id

    # Schedule once, then report the remote prefill as received.
    sched_out = scheduler.schedule()
    recv_done = create_model_runner_output(reqs=[], finished_recving={request_id})
    scheduler.update_from_output(sched_out, recv_done)

    # Run the request to an EOS stop and capture the engine outputs.
    sched_out = scheduler.schedule()
    eos_out = create_model_runner_output(reqs=[request], use_eos=True)
    engine_outputs = scheduler.update_from_output(sched_out, eos_out)

    kv = engine_outputs[0].outputs[0].kv_transfer_params
    assert kv is not None
    total_blocks = sum(len(group) for group in kv["remote_block_ids"])
    max_tokens_in_blocks = total_blocks * block_size
    assert kv["remote_num_tokens"] <= max_tokens_in_blocks, (
        f"remote_num_tokens ({kv['remote_num_tokens']}) exceeds "
        f"block capacity ({max_tokens_in_blocks})"
    )
    assert kv["remote_num_tokens"] > 0
|
||||
|
||||
|
||||
def test_kv_recompute_threshold_skips_small_transfer():
    """Edge case 3: small remote transfers are skipped in favor of recompute.

    With ``kv_recompute_threshold`` set to 256 and a request offering only
    three blocks of remote tokens, the request must go straight to
    ``RUNNING`` rather than ``WAITING_FOR_REMOTE_KVS``. The remainder of the
    test drains the scheduler.
    """
    threshold = 256
    extra_config = {
        "bidirectional_kv_xfer": True,
        "kv_recompute_threshold": threshold,
    }
    config = create_vllm_config(kv_connector_extra_config=extra_config)
    scheduler = create_scheduler(config)
    block_size = config.cache_config.block_size

    # Three blocks of remote tokens for a 2.5-block prompt; the test expects
    # this to fall under the configured threshold.
    request = _make_p_node_turn2_request(
        502,
        block_size,
        int(block_size * 2.5),
        num_remote_blocks=3,
        remote_num_tokens=3 * block_size,
    )
    scheduler.add_request(request)
    sched_out = scheduler.schedule()
    # The remote pull is skipped, so the request runs immediately.
    assert request.status != RequestStatus.WAITING_FOR_REMOTE_KVS
    assert request.status == RequestStatus.RUNNING

    # Clean up: run to completion, then report the send as finished.
    runner_out = create_model_runner_output(reqs=[request])
    scheduler.update_from_output(sched_out, runner_out)
    assert request.status == RequestStatus.FINISHED_LENGTH_CAPPED
    sched_out = scheduler.schedule()
    scheduler.update_from_output(sched_out, EMPTY_MODEL_RUNNER_OUTPUT)
    sched_out = scheduler.schedule()
    send_done = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
    send_done.kv_connector_output = KVConnectorOutput(
        finished_sending={request.request_id}
    )
    scheduler.update_from_output(sched_out, send_done)
    assert_scheduler_empty(scheduler)
|
||||
|
||||
|
||||
def test_p_node_finished_holds_blocks_for_d():
    """Edge case 4: a finished P-node request is kept until D has read it.

    A ``do_remote_decode=True`` request that finishes as length-capped must
    return ``kv_transfer_params`` with ``do_remote_prefill=True``,
    ``do_remote_decode=False`` and non-empty ``remote_block_ids``, and must
    remain in ``scheduler.requests`` until the send is reported as finished.
    """
    config = create_vllm_config(
        kv_connector_extra_config=BIDIR_KV_EXTRA_CONFIG,
    )
    scheduler = create_scheduler(config)
    block_size = config.cache_config.block_size
    request = create_request(
        request_id=503,
        block_size=block_size,
        num_tokens=int(block_size * 2.5),
        do_remote_decode=True,
    )
    scheduler.add_request(request)
    request_id = request.request_id

    # Run the prefill in a single step.
    sched_out = scheduler.schedule()
    runner_out = create_model_runner_output(reqs=[request])
    engine_outputs = scheduler.update_from_output(sched_out, runner_out)
    assert request.status == RequestStatus.FINISHED_LENGTH_CAPPED

    kv = engine_outputs[0].outputs[0].kv_transfer_params
    assert kv is not None
    # The returned params instruct the D side to pull from this node.
    assert kv["do_remote_prefill"] is True
    assert kv["do_remote_decode"] is False
    assert "remote_block_ids" in kv
    assert sum(len(group) for group in kv["remote_block_ids"]) > 0
    # The request is still tracked, i.e. it has not been released yet.
    assert request_id in scheduler.requests

    # Clean up: one idle step, then report the send as finished.
    sched_out = scheduler.schedule()
    scheduler.update_from_output(sched_out, EMPTY_MODEL_RUNNER_OUTPUT)
    sched_out = scheduler.schedule()
    send_done = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
    send_done.kv_connector_output = KVConnectorOutput(finished_sending={request_id})
    scheduler.update_from_output(sched_out, send_done)
    assert_scheduler_empty(scheduler)
|
||||
|
||||
|
||||
def test_cache_miss_first_turn_no_remote_pull():
    """Edge case 5: no ``remote_block_ids`` means no remote pull.

    A ``do_remote_decode=True`` request created without ``remote_block_ids``
    must be ``RUNNING`` after the first schedule instead of waiting for
    remote KVs. The remainder of the test drains the scheduler.
    """
    config = create_vllm_config(
        kv_connector_extra_config=BIDIR_KV_EXTRA_CONFIG,
    )
    scheduler = create_scheduler(config)
    block_size = config.cache_config.block_size
    request = create_request(
        request_id=504,
        block_size=block_size,
        num_tokens=int(block_size * 2.5),
        do_remote_decode=True,
    )
    # Premise: the request carries no remote block ids (cache miss).
    assert request.kv_transfer_params.get("remote_block_ids") is None
    scheduler.add_request(request)

    sched_out = scheduler.schedule()
    # Nothing to pull, so the request runs right away.
    assert request.status != RequestStatus.WAITING_FOR_REMOTE_KVS
    assert request.status == RequestStatus.RUNNING

    # Clean up: run to completion, then report the send as finished.
    runner_out = create_model_runner_output(reqs=[request])
    scheduler.update_from_output(sched_out, runner_out)
    sched_out = scheduler.schedule()
    scheduler.update_from_output(sched_out, EMPTY_MODEL_RUNNER_OUTPUT)
    sched_out = scheduler.schedule()
    send_done = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
    send_done.kv_connector_output = KVConnectorOutput(
        finished_sending={request.request_id}
    )
    scheduler.update_from_output(sched_out, send_done)
    assert_scheduler_empty(scheduler)
|
||||
|
||||
|
||||
def test_partial_remote_tokens_less_than_prompt():
    """Edge case 6: D's KV covers only a prefix of P's prompt.

    The prompt is 4.5 blocks long while D offers 2 blocks of remote tokens.
    After the first schedule the request must be waiting for remote KVs with
    ``num_computed_tokens`` strictly below the prompt length; the request is
    then driven through recv -> run -> send until the scheduler is empty.
    """
    config = create_vllm_config(
        kv_connector_extra_config=BIDIR_KV_EXTRA_CONFIG,
    )
    scheduler = create_scheduler(config)
    block_size = config.cache_config.block_size
    prompt_len = int(block_size * 4.5)
    # D provides only two full blocks out of the 4.5-block prompt.
    request = _make_p_node_turn2_request(
        505,
        block_size,
        prompt_len,
        num_remote_blocks=2,
        remote_num_tokens=2 * block_size,
    )
    scheduler.add_request(request)
    request_id = request.request_id

    sched_out = scheduler.schedule()
    assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
    # Only part of the prompt is accounted for as computed so far.
    assert request.num_computed_tokens < prompt_len

    # Complete the transfer.
    scheduler.update_from_output(sched_out, EMPTY_MODEL_RUNNER_OUTPUT)
    sched_out = scheduler.schedule()
    recv_done = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
    recv_done.kv_connector_output = KVConnectorOutput(finished_recving={request_id})
    scheduler.update_from_output(sched_out, recv_done)

    # Run the request; it must finish as length-capped.
    sched_out = scheduler.schedule()
    runner_out = create_model_runner_output(reqs=[request])
    scheduler.update_from_output(sched_out, runner_out)
    assert request.status == RequestStatus.FINISHED_LENGTH_CAPPED

    # One idle step, then report the send as finished.
    sched_out = scheduler.schedule()
    scheduler.update_from_output(sched_out, EMPTY_MODEL_RUNNER_OUTPUT)
    sched_out = scheduler.schedule()
    send_done = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
    send_done.kv_connector_output = KVConnectorOutput(finished_sending={request_id})
    scheduler.update_from_output(sched_out, send_done)
    assert_scheduler_empty(scheduler)
|
||||
|
||||
|
||||
def test_remote_blocks_processed_flag_persists():
    """Edge case 7: ``_remote_blocks_processed`` blocks a second receive.

    Once the receive is reported as finished, the request's
    ``kv_transfer_params`` must have ``_remote_blocks_processed`` set to
    ``True``. On the next schedule the request must appear neither in the
    connector scheduler's ``_reqs_need_recv`` nor in the emitted
    ``reqs_to_recv``. The remainder of the test drains the scheduler.
    """
    config = create_vllm_config(
        kv_connector_extra_config=BIDIR_KV_EXTRA_CONFIG,
    )
    scheduler = create_scheduler(config)
    block_size = config.cache_config.block_size
    request = _make_p_node_turn2_request(506, block_size, int(block_size * 2.5))
    scheduler.add_request(request)
    request_id = request.request_id
    connector_scheduler = scheduler.connector.connector_scheduler

    # First schedule: the request waits for remote KVs.
    sched_out = scheduler.schedule()
    assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
    scheduler.update_from_output(sched_out, EMPTY_MODEL_RUNNER_OUTPUT)

    # Report the receive as finished.
    sched_out = scheduler.schedule()
    recv_done = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
    recv_done.kv_connector_output = KVConnectorOutput(finished_recving={request_id})
    scheduler.update_from_output(sched_out, recv_done)

    # The flag must now be recorded on the request's params.
    assert request.kv_transfer_params.get("_remote_blocks_processed") is True

    # Next schedule: the request must not be queued for receive again.
    sched_out = scheduler.schedule()
    assert request_id not in connector_scheduler._reqs_need_recv
    connector_meta = sched_out.kv_connector_metadata
    assert isinstance(connector_meta, NixlConnectorMetadata)
    assert request_id not in connector_meta.reqs_to_recv

    # Clean up: run to completion, then report the send as finished.
    runner_out = create_model_runner_output(reqs=[request])
    scheduler.update_from_output(sched_out, runner_out)
    sched_out = scheduler.schedule()
    scheduler.update_from_output(sched_out, EMPTY_MODEL_RUNNER_OUTPUT)
    sched_out = scheduler.schedule()
    send_done = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
    send_done.kv_connector_output = KVConnectorOutput(finished_sending={request_id})
    scheduler.update_from_output(sched_out, send_done)
    assert_scheduler_empty(scheduler)
|
||||
@@ -119,6 +119,30 @@ class NixlConnectorScheduler:
|
||||
for n_tokens, block_size in sw_sizes_tokens
|
||||
]
|
||||
|
||||
# Threshold to decide whether to compute kv cache locally
|
||||
# or pull from a remote node: minimum number of remote
|
||||
# tokens to amortize the xfer latencies
|
||||
self.kv_recompute_threshold: int = int(
|
||||
vllm_config.kv_transfer_config.get_from_extra_config(
|
||||
"kv_recompute_threshold", 64
|
||||
)
|
||||
)
|
||||
|
||||
# Bi-directional KV transfer feature supports KV block
|
||||
# transfers from D node to P node
|
||||
self.is_bidirectional_kv_xfer_enabled = (
|
||||
vllm_config.kv_transfer_config.get_from_extra_config(
|
||||
"bidirectional_kv_xfer", False
|
||||
)
|
||||
)
|
||||
|
||||
if self.is_bidirectional_kv_xfer_enabled and self.kv_recompute_threshold > 0:
|
||||
logger.info(
|
||||
"Bidirectional KV transfer is enabled and the kv "
|
||||
"recompute threshold is set to %d tokens",
|
||||
self.kv_recompute_threshold,
|
||||
)
|
||||
|
||||
def shutdown(self):
|
||||
self._stop_event.set()
|
||||
if self._nixl_handshake_listener_t is not None:
|
||||
@@ -298,6 +322,44 @@ class NixlConnectorScheduler:
|
||||
if params is not None and params.get("do_remote_decode") and self._has_mamba:
|
||||
self._truncate_mamba_request_for_prefill(request)
|
||||
|
||||
if (
|
||||
params is not None
|
||||
and params.get("do_remote_decode")
|
||||
and params.get("remote_block_ids")
|
||||
and all(
|
||||
p in params
|
||||
for p in (
|
||||
"remote_engine_id",
|
||||
"remote_request_id",
|
||||
"remote_host",
|
||||
"remote_port",
|
||||
)
|
||||
)
|
||||
):
|
||||
# Decode node has kv blocks for part of prefill request, so, provide them
|
||||
# as an external token count to scheduler.
|
||||
# The tokens will be loaded if not already present
|
||||
# in the prefill node local cache
|
||||
remote_num_tokens = params.get("remote_num_tokens") or 0
|
||||
count = (
|
||||
min(remote_num_tokens, request.num_prompt_tokens) - num_computed_tokens
|
||||
)
|
||||
if count > 0:
|
||||
# Check kv_recompute_threshold: skip pull if
|
||||
# remote tokens are below the threshold.
|
||||
if (
|
||||
self.kv_recompute_threshold > 0
|
||||
and count < self.kv_recompute_threshold
|
||||
):
|
||||
logger.debug(
|
||||
"Skipping remote pull for %s: %d remote tokens < threshold %d",
|
||||
request.request_id,
|
||||
count,
|
||||
self.kv_recompute_threshold,
|
||||
)
|
||||
return 0, False
|
||||
return count, True
|
||||
|
||||
# No remote prefill for this request.
|
||||
return 0, False
|
||||
|
||||
@@ -315,13 +377,19 @@ class NixlConnectorScheduler:
|
||||
if not params:
|
||||
return
|
||||
|
||||
if params.get("do_remote_decode"):
|
||||
if params.get("do_remote_decode") or (
|
||||
params.get("do_remote_prefill") and self.is_bidirectional_kv_xfer_enabled
|
||||
):
|
||||
self._reqs_in_batch.add(request.request_id)
|
||||
if self.use_host_buffer and params.get("do_remote_decode"):
|
||||
# NOTE: when accelerator is not directly supported by Nixl,
|
||||
# prefilled blocks need to be saved to host memory before transfer.
|
||||
self._reqs_need_save[request.request_id] = request
|
||||
elif params.get("do_remote_prefill"):
|
||||
elif params.get("do_remote_prefill") or (
|
||||
params.get("do_remote_decode")
|
||||
and self.is_bidirectional_kv_xfer_enabled
|
||||
and not params.get("_remote_blocks_processed")
|
||||
):
|
||||
if params.get("remote_block_ids"):
|
||||
if all(
|
||||
p in params
|
||||
@@ -333,8 +401,8 @@ class NixlConnectorScheduler:
|
||||
)
|
||||
):
|
||||
# If remote_blocks and num_external_tokens = 0, we have
|
||||
# a full prefix cache hit on the D worker. We need to call
|
||||
# send_notif in _read_blocks to free the memory on the P.
|
||||
# a full prefix cache hit on the local node. We need to call
|
||||
# send_notif in _read_blocks to free the memory on the remote node.
|
||||
|
||||
unhashed_local_block_ids: BlockIds = (
|
||||
blocks.get_unhashed_block_ids_all_groups()
|
||||
@@ -362,6 +430,7 @@ class NixlConnectorScheduler:
|
||||
assert num_external_tokens == 0
|
||||
# Only trigger 1 KV transfer per request.
|
||||
params["do_remote_prefill"] = False
|
||||
params["_remote_blocks_processed"] = True
|
||||
|
||||
def _build_save_meta(
|
||||
self,
|
||||
@@ -450,6 +519,9 @@ class NixlConnectorScheduler:
|
||||
if not params:
|
||||
return False, None
|
||||
|
||||
is_p_node = bool(params.get("do_remote_decode"))
|
||||
is_d_node = not is_p_node
|
||||
|
||||
if params.get("do_remote_prefill"):
|
||||
# If do_remote_prefill is still True when the request is finished,
|
||||
# update_state_after_alloc must not have been called (the request
|
||||
@@ -461,9 +533,13 @@ class NixlConnectorScheduler:
|
||||
params["do_remote_prefill"] = False
|
||||
return False, None
|
||||
|
||||
if not params.get("do_remote_decode"):
|
||||
if is_d_node and not self.is_bidirectional_kv_xfer_enabled:
|
||||
return False, None
|
||||
if request.status != RequestStatus.FINISHED_LENGTH_CAPPED:
|
||||
|
||||
if request.status not in (
|
||||
RequestStatus.FINISHED_LENGTH_CAPPED,
|
||||
RequestStatus.FINISHED_STOPPED,
|
||||
):
|
||||
# Also include the case of a P/D Prefill request with immediate
|
||||
# block free (eg abort). Stop tracking this request.
|
||||
self._reqs_not_processed.add(request.request_id)
|
||||
@@ -474,7 +550,7 @@ class NixlConnectorScheduler:
|
||||
# TODO: check whether block_ids actually ever be 0. If not we could
|
||||
# remove the conditional below
|
||||
delay_free_blocks = any(len(group) > 0 for group in block_ids)
|
||||
|
||||
remote_num_tokens = 0
|
||||
if delay_free_blocks:
|
||||
# Prefill request on remote. It will be read from D upon completion
|
||||
logger.debug(
|
||||
@@ -492,13 +568,16 @@ class NixlConnectorScheduler:
|
||||
# Here we "unpad" blocks to send the actual remote blocks to be read.
|
||||
block_ids = self.get_sw_clipped_blocks(block_ids)
|
||||
|
||||
remote_num_tokens = request.num_computed_tokens
|
||||
|
||||
return delay_free_blocks, dict(
|
||||
do_remote_prefill=True,
|
||||
do_remote_decode=False,
|
||||
do_remote_prefill=is_p_node,
|
||||
do_remote_decode=is_d_node,
|
||||
remote_block_ids=block_ids,
|
||||
remote_engine_id=self.engine_id,
|
||||
remote_request_id=request.request_id,
|
||||
remote_host=self.side_channel_host,
|
||||
remote_port=self.side_channel_port,
|
||||
tp_size=self.vllm_config.parallel_config.tensor_parallel_size,
|
||||
remote_num_tokens=remote_num_tokens,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user