[P/D] Prefill compute optimizations with bi-directional KV cache transfers between P and D nodes (#32553)

Signed-off-by: Sunita Nadampalli <nadampal@amazon.com>
This commit is contained in:
snadampal
2026-04-30 03:14:20 -07:00
committed by GitHub
parent efdc95674d
commit 3179e53135
4 changed files with 1570 additions and 9 deletions
@@ -217,6 +217,7 @@ async def send_request(
min_tokens: int | None = None,
max_tokens: int | None = None,
timeout_sec: int = 120,
conversation_id: str | None = None,
) -> ServerResponse:
payload = {
"model": model,
@@ -225,6 +226,9 @@ async def send_request(
"temperature": 0.0,
}
if conversation_id is not None:
payload["conversation_id"] = conversation_id
if stream:
payload["stream"] = True
payload["stream_options"] = {"include_usage": False}
@@ -419,6 +423,7 @@ async def send_turn(
min_tokens,
max_tokens,
req_args.timeout_sec,
conversation_id=conv_id,
)
if response.valid is False:
@@ -0,0 +1,562 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Disaggregated Prefill/Decode Proxy with Bidirectional KV Transfer
This proxy sits between clients and a vLLM Prefill/Decode (P/D) deployment,
routing multi-turn chat requests so that each turn reuses KV cache blocks
from the previous turn's Decode node via bidirectional KV transfer.
Architecture:
Client ──► Proxy ──► Prefill (P) ──► Decode (D)
│ │ │
│ kv_transfer_params flow: │
│ D finish ──► proxy caches │
│ next turn ──► proxy sends │
│ cached D blocks to P ──► │
│ P reads D blocks (bidir) │
│ P sends its blocks to D │
Per-request flow:
1. Client sends chat/completions request to proxy.
2. Proxy looks up cached D block info from the previous turn
(keyed by conversation_id).
3. If cache hit, proxy attaches D's block info to the request
so P can read D's KV blocks instead of recomputing.
4. Proxy sends request to P (max_tokens=1, non-streaming).
5. P returns kv_transfer_params with its own block info.
6. Proxy forwards request + P's block info to D (streaming).
7. D streams the response. The final chunk includes D's
kv_transfer_params, which the proxy caches for the next turn.
8. Proxy returns D's response to the client.
Conversation isolation:
Each request must include a ``conversation_id`` field (top-level in
the JSON body) to scope the KV cache across turns. Without it, the
proxy cannot link turns and falls back to no-cache behavior.
Usage:
python disagg_proxy_multiturn.py \\
--host 0.0.0.0 --port 8000 \\
--prefiller-host 10.0.0.1 --prefiller-port 8100 \\
--decoder-host 10.0.0.2 --decoder-port 8200
Dependencies:
pip install fastapi uvicorn httpx
"""
from __future__ import annotations
import argparse
import itertools
import json
import logging
import os
import time
import uuid
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from typing import Any
import httpx
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse
# Logging
# Configure the root logger at import time: timestamped records, INFO and above.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
)
# Named logger used by every helper in this proxy module.
logger = logging.getLogger("disagg_proxy")
# Data structures
@dataclass
class CachedKVEntry:
    """KV transfer parameters cached from D's response for one turn."""

    # The kv_transfer_params dict reported by D for the finished turn.
    kv_transfer_params: dict[str, Any]
    # Wall-clock creation time (time.time()); defaults to "now" per instance.
    timestamp: float = field(default_factory=time.time)
class ConversationKVCache:
    """Per-conversation KV block cache.

    Each conversation is identified by a ``conversation_id`` supplied by
    the client. After D finishes a turn, its ``kv_transfer_params`` are
    stored here. On the next turn, the proxy retrieves them so P can
    read D's blocks via bidirectional KV transfer.

    Entries older than ``ttl_seconds`` are discarded on lookup, and the
    whole store is swept at most once per TTL period from ``put()`` so
    that conversations which never get a follow-up turn cannot pile up.
    """

    def __init__(self, ttl_seconds: float = 600.0) -> None:
        self._store: dict[str, CachedKVEntry] = {}
        self._ttl = ttl_seconds
        # Time of the most recent full sweep; rate-limits the sweep in put().
        self._last_sweep = time.time()

    def get(self, conversation_id: str) -> dict[str, Any] | None:
        """Retrieve and consume cached KV params for a conversation.

        Returns a shallow *copy* of the kv_transfer_params dict, or None
        when there is no entry or the entry has outlived the TTL.
        The entry is removed after retrieval (single-use).
        """
        entry = self._store.pop(conversation_id, None)
        if entry is None:
            return None
        age = time.time() - entry.timestamp
        if age > self._ttl:
            logger.info(
                "conv=%s: stale cache entry (age=%.1fs > ttl=%.1fs), discarding",
                conversation_id,
                age,
                self._ttl,
            )
            return None
        logger.info(
            "conv=%s: cache HIT (age=%.1fs)",
            conversation_id,
            age,
        )
        return dict(entry.kv_transfer_params)

    def put(self, conversation_id: str, kv_params: dict[str, Any]) -> None:
        """Store D's kv_transfer_params for a conversation.

        Any existing entry for the same conversation is replaced.
        """
        # Opportunistic cleanup: without this, stale entries are only removed
        # when something calls evict_stale() explicitly, so abandoned
        # conversations would stay in memory indefinitely.
        if time.time() - self._last_sweep >= self._ttl:
            self.evict_stale()
        self._store[conversation_id] = CachedKVEntry(
            kv_transfer_params=dict(kv_params),  # defensive copy
        )
        block_groups = kv_params.get("remote_block_ids")
        logger.info(
            "conv=%s: cached D blocks (remote_request_id=%s, blocks=%d)",
            conversation_id,
            kv_params.get("remote_request_id", "?"),
            # Only the first block group is counted in the log line.
            len(block_groups[0]) if block_groups else 0,
        )

    def evict_stale(self) -> int:
        """Remove entries older than TTL. Returns count of evicted entries."""
        now = time.time()
        self._last_sweep = now
        stale = [
            cid
            for cid, entry in self._store.items()
            if now - entry.timestamp > self._ttl
        ]
        for cid in stale:
            del self._store[cid]
        return len(stale)

    @property
    def size(self) -> int:
        """Number of entries currently stored, stale ones included."""
        return len(self._store)
# Global state
# Module-level cache instance used by the request handlers and /health.
kv_cache = ConversationKVCache(
    ttl_seconds=450.0
)  # Must be < VLLM_NIXL_ABORT_REQUEST_TIMEOUT (480s)
# Service client helpers
@dataclass
class ServiceClient:
    """Wrapper around an httpx.AsyncClient for a P or D instance."""

    # HTTP client used for all requests to this instance.
    client: httpx.AsyncClient
    # Host of the instance; also written into kv_transfer_params as
    # "remote_host" by the request helpers.
    host: str
    # Port of the instance's HTTP server.
    port: int
    # Index of this instance within its pool.
    id: int
def _make_headers(request_id: str) -> dict[str, str]:
    """Return the headers attached to every upstream request.

    Always carries ``X-Request-Id``. A bearer ``Authorization`` header is
    added only when the ``OPENAI_API_KEY`` environment variable is non-empty.
    """
    token = os.environ.get("OPENAI_API_KEY")
    if not token:
        return {"X-Request-Id": request_id}
    return {"X-Request-Id": request_id, "Authorization": f"Bearer {token}"}
async def _send_to_prefill(
    client: ServiceClient,
    endpoint: str,
    req_data: dict[str, Any],
    request_id: str,
) -> dict[str, Any]:
    """Run the prefill leg of a request against a P instance.

    The caller's ``req_data`` is left untouched; a derived payload is sent
    with streaming disabled and generation capped at a single token.
    Returns P's decoded JSON body (expected to carry kv_transfer_params).
    Raises whatever ``raise_for_status`` raises on a non-success response.
    """
    # Options that would conflict with the forced single-token,
    # non-streaming call are dropped from the forwarded payload.
    dropped = ("max_completion_tokens", "min_tokens", "stream_options")
    prefill_body = {
        key: value for key, value in req_data.items() if key not in dropped
    }
    prefill_body["stream"] = False
    prefill_body["max_tokens"] = 1

    response = await client.client.post(
        endpoint,
        json=prefill_body,
        headers=_make_headers(request_id),
    )
    response.raise_for_status()
    return response.json()
async def _stream_from_decode(
    client: ServiceClient,
    endpoint: str,
    req_data: dict[str, Any],
    request_id: str,
    conversation_id: str,
) -> tuple[str, str | None, dict[str, Any] | None, str, str | None, int | None]:
    """Stream response from D, capturing text and kv_transfer_params.

    Returns a 6-tuple
    ``(collected_text, finish_reason, kv_params, response_id, model_name,
    created)``. ``response_id`` falls back to ``request_id`` when D never
    reports an id. ``kv_params`` is the last non-empty
    ``kv_transfer_params`` seen, with ``remote_host`` set to D's host.

    When ``conversation_id`` is non-empty, the captured kv_params are also
    stored in the conversation cache.

    Raises whatever ``raise_for_status`` raises on a non-success response.
    """
    payload = req_data.copy()
    payload["stream"] = True
    # Pieces are joined once at the end instead of repeated str +=.
    text_parts: list[str] = []
    finish_reason: str | None = None
    response_id: str | None = None
    model_name: str | None = None
    created: int | None = None
    captured_kv: dict[str, Any] | None = None
    async with client.client.stream(
        "POST",
        endpoint,
        json=payload,
        headers=_make_headers(request_id),
    ) as resp:
        resp.raise_for_status()
        async for line in resp.aiter_lines():
            # Only SSE "data:" events carry payloads; skip blanks/comments.
            if not line or not line.startswith("data: "):
                continue
            if line == "data: [DONE]":
                break
            try:
                chunk = json.loads(line[6:])
            except json.JSONDecodeError:
                # Best effort: an unparsable event is skipped, not fatal.
                continue
            if response_id is None:
                response_id = chunk.get("id")
                model_name = chunk.get("model")
                created = chunk.get("created")
            # "choices", "delta", "text" and "content" may be present but
            # null in a chunk; treat null exactly like "absent" so the
            # accumulation never tries to concatenate None.
            for choice in chunk.get("choices") or []:
                text_parts.append(choice.get("text") or "")
                delta = choice.get("delta") or {}
                text_parts.append(delta.get("content") or "")
                if choice.get("finish_reason"):
                    finish_reason = choice["finish_reason"]
            kv_params = chunk.get("kv_transfer_params")
            if kv_params:
                # Record which D instance owns the blocks.
                kv_params["remote_host"] = client.host
                captured_kv = kv_params
                if conversation_id:
                    kv_cache.put(conversation_id, kv_params)
    return (
        "".join(text_parts),
        finish_reason,
        captured_kv,
        response_id or request_id,
        model_name,
        created,
    )
async def _stream_from_decode_sse(
    client: ServiceClient,
    endpoint: str,
    req_data: dict[str, Any],
    request_id: str,
    conversation_id: str,
):
    """Relay D's SSE stream line by line while watching for kv_transfer_params.

    Every line received from D is re-emitted followed by a newline (empty
    lines become a bare newline). Data events other than the ``[DONE]``
    marker are additionally parsed; when one carries kv_transfer_params and
    a conversation id is known, the params are tagged with D's host and
    cached for the next turn. Events that are not valid JSON are forwarded
    unchanged.
    """
    # Same key order as copy-then-assign: an existing "stream" key keeps its
    # position, otherwise it is appended.
    upstream_body = {**req_data, "stream": True}

    async with client.client.stream(
        "POST",
        endpoint,
        json=upstream_body,
        headers=_make_headers(request_id),
    ) as upstream:
        upstream.raise_for_status()
        async for raw in upstream.aiter_lines():
            if not raw:
                yield "\n"
                continue

            carries_payload = raw.startswith("data: ") and raw != "data: [DONE]"
            if carries_payload:
                try:
                    event = json.loads(raw[6:])
                except json.JSONDecodeError:
                    pass
                else:
                    d_kv = event.get("kv_transfer_params")
                    if d_kv and conversation_id:
                        d_kv["remote_host"] = client.host
                        kv_cache.put(conversation_id, d_kv)

            yield raw + "\n"
# FastAPI application
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the P and D HTTP clients at startup and close them at shutdown.

    Reads the ``(host, port)`` pairs from the module-level ``global_args``
    and stores on ``app.state``: the two client pools plus one round-robin
    index iterator per pool.

    Cleanup runs in a ``finally`` block so that already-created clients are
    closed even if startup fails part-way or the application exits through
    an exception.
    """
    app.state.prefill_clients = []
    app.state.decode_clients = []
    try:
        pools = (
            (app.state.prefill_clients, global_args.prefiller_instances),
            (app.state.decode_clients, global_args.decoder_instances),
        )
        for pool, instances in pools:
            for i, (host, port) in enumerate(instances):
                pool.append(
                    ServiceClient(
                        client=httpx.AsyncClient(
                            timeout=None,
                            base_url=f"http://{host}:{port}/v1",
                        ),
                        host=host,
                        port=port,
                        id=i,
                    )
                )
        # Endless index sequences used by _next_client for round-robin.
        app.state.prefill_iter = itertools.cycle(
            range(len(app.state.prefill_clients))
        )
        app.state.decode_iter = itertools.cycle(
            range(len(app.state.decode_clients))
        )
        logger.info(
            "Ready: %d prefill, %d decode instances",
            len(app.state.prefill_clients),
            len(app.state.decode_clients),
        )
        yield
    finally:
        for sc in app.state.prefill_clients + app.state.decode_clients:
            await sc.client.aclose()
# ASGI application object; lifespan() manages the upstream HTTP clients.
app = FastAPI(title="Disaggregated P/D Proxy (Multi-turn)", lifespan=lifespan)
def _next_client(app_state, role: str) -> ServiceClient:
    """Pick the next instance for ``role`` in round-robin order.

    ``role == "prefill"`` selects from the prefill pool; any other value
    selects from the decode pool.
    """
    if role == "prefill":
        pool, cursor = app_state.prefill_clients, app_state.prefill_iter
    else:
        pool, cursor = app_state.decode_clients, app_state.decode_iter
    return pool[next(cursor)]
# Request handler
async def _handle_request(api_path: str, request: Request):
    """Proxy one completion request through Prefill and then Decode.

    Shared by the /v1/chat/completions and /v1/completions routes;
    ``api_path`` is the upstream path passed to the P and D clients.

    Returns a StreamingResponse relaying D's SSE stream when the client
    asked for streaming; otherwise a JSONResponse assembled from the text
    collected from D.
    """
    req_data = await request.json()
    request_id = str(uuid.uuid4())
    conversation_id: str = req_data.pop("conversation_id", "")
    client_wants_stream = req_data.get("stream", False)

    # Step 1: Look up cached D blocks from the previous turn
    if conversation_id:
        prior_turn_kv = kv_cache.get(conversation_id)
    else:
        logger.warning(
            "[%s] No conversation_id provided — KV cache reuse disabled "
            "for this request. Add a 'conversation_id' field to enable "
            "cross-turn KV sharing.",
            request_id,
        )
        prior_turn_kv = None

    if prior_turn_kv:
        # Tell P to read D's blocks (bidirectional transfer)
        prior_turn_kv.update(do_remote_decode=True, do_remote_prefill=False)
        req_data["kv_transfer_params"] = prior_turn_kv
        logger.info(
            "[%s] conv=%s: sending D's cached blocks to P (remote_request_id=%s)",
            request_id,
            conversation_id,
            prior_turn_kv.get("remote_request_id"),
        )
    else:
        # No cached blocks — P recomputes from scratch
        req_data["kv_transfer_params"] = {
            "do_remote_decode": True,
            "do_remote_prefill": False,
            "remote_engine_id": None,
            "remote_block_ids": None,
            "remote_host": None,
            "remote_port": None,
        }
        logger.info("[%s] conv=%s: cache MISS", request_id, conversation_id)

    # Step 2: Send to Prefill node (non-streaming, max_tokens=1)
    prefill_client = _next_client(request.app.state, "prefill")
    prefill_started = time.time()
    prefill_resp = await _send_to_prefill(
        prefill_client, api_path, req_data, request_id
    )
    logger.info(
        "[%s] Prefill done in %.0fms",
        request_id,
        (time.time() - prefill_started) * 1000,
    )

    # Hand P's kv_transfer_params to D so it can read P's blocks. If P
    # returned none, the params sent to P are forwarded to D unchanged.
    from_prefill = prefill_resp.get("kv_transfer_params", {})
    if from_prefill:
        from_prefill["remote_host"] = prefill_client.host
        req_data["kv_transfer_params"] = from_prefill

    # Step 3: Stream from Decode node, capturing kv_transfer_params
    decode_client = _next_client(request.app.state, "decode")
    decode_call = (decode_client, api_path, req_data, request_id, conversation_id)

    if client_wants_stream:
        return StreamingResponse(
            _stream_from_decode_sse(*decode_call),
            media_type="text/event-stream",
        )

    text, finish_reason, _, resp_id, model, created = await _stream_from_decode(
        *decode_call
    )

    # Build OpenAI-compatible response: the two shapes differ only in the
    # object type and in the layout of the single choice.
    if "messages" in req_data:
        object_type = "chat.completion"
        choice = {
            "index": 0,
            "message": {"role": "assistant", "content": text},
            "finish_reason": finish_reason,
        }
    else:
        object_type = "text_completion"
        choice = {
            "index": 0,
            "text": text,
            "logprobs": None,
            "finish_reason": finish_reason,
        }

    return JSONResponse(
        content={
            "id": resp_id,
            "object": object_type,
            "created": created or int(time.time()),
            "model": model or req_data.get("model", ""),
            "choices": [choice],
            "usage": None,
        }
    )
# Routes
@app.post("/v1/chat/completions")
async def chat_completions(request: Request):
    """Route chat-completions requests through the shared P/D handler."""
    # The upstream path is passed without the /v1 prefix.
    return await _handle_request("/chat/completions", request)
@app.post("/v1/completions")
async def completions(request: Request):
    """Route text-completions requests through the shared P/D handler."""
    # The upstream path is passed without the /v1 prefix.
    return await _handle_request("/completions", request)
@app.get("/health")
async def health():
    """Report liveness and cache statistics.

    Each call first purges expired cache entries, so the reported size
    reflects the store after eviction.
    """
    purged = kv_cache.evict_stale()
    remaining = kv_cache.size
    return {
        "status": "ok",
        "cached_conversations": remaining,
        "evicted_stale": purged,
    }
# CLI
def parse_args() -> argparse.Namespace:
    """Parse the proxy's command-line options.

    Besides the declared options, the returned namespace carries
    ``prefiller_instances`` and ``decoder_instances``: lists of
    ``(host, port)`` tuples built by pairing the host and port lists.
    Exits through ``ArgumentParser.error`` when a host list and its port
    list differ in length.
    """
    parser = argparse.ArgumentParser(
        description="Disaggregated P/D proxy with bidirectional KV transfer",
    )
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--port", type=int, default=8000)

    # Each list option accepts a singular and a plural spelling and takes
    # one or more values.
    list_options = (
        (
            ("--prefiller-host", "--prefiller-hosts"),
            {"dest": "prefiller_hosts", "default": ["localhost"]},
        ),
        (
            ("--prefiller-port", "--prefiller-ports"),
            {"dest": "prefiller_ports", "type": int, "default": [8100]},
        ),
        (
            ("--decoder-host", "--decoder-hosts"),
            {"dest": "decoder_hosts", "default": ["localhost"]},
        ),
        (
            ("--decoder-port", "--decoder-ports"),
            {"dest": "decoder_ports", "type": int, "default": [8200]},
        ),
    )
    for flags, options in list_options:
        parser.add_argument(*flags, nargs="+", **options)

    namespace = parser.parse_args()

    length_checks = (
        (
            namespace.prefiller_hosts,
            namespace.prefiller_ports,
            "Number of prefiller hosts must match ports",
        ),
        (
            namespace.decoder_hosts,
            namespace.decoder_ports,
            "Number of decoder hosts must match ports",
        ),
    )
    for hosts, ports, complaint in length_checks:
        if len(hosts) != len(ports):
            parser.error(complaint)

    namespace.prefiller_instances = list(
        zip(namespace.prefiller_hosts, namespace.prefiller_ports)
    )
    namespace.decoder_instances = list(
        zip(namespace.decoder_hosts, namespace.decoder_ports)
    )
    return namespace
if __name__ == "__main__":
    # Assigned at module scope so lifespan() can read it when the server
    # starts; a ``global`` statement is redundant at this level.
    global_args = parse_args()

    import uvicorn

    uvicorn.run(app, host=global_args.host, port=global_args.port)
@@ -0,0 +1,915 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unit tests for bi-directional KV cache transfer between P and D nodes.
Tests cover the new behaviors added by the bi-directional KV transfer PR:
1. P-node scheduler lifecycle: P pulls KV from D using remote_block_ids,
eliminating redundant prefill computation in multi-turn conversations.
2. P-node metadata: NixlConnectorMetadata correctly populates recv metadata
when P pulls KV from D (do_remote_decode=True + remote_block_ids).
3. P-node worker: start_load_kv processes reqs_to_recv for KV pull from D.
4. D-node request_finished: returns kv_transfer_params with remote_block_ids
and remote_num_tokens so P can pull KV in future turns.
5. Edge cases:
- No double read after reschedule (_remote_blocks_processed flag)
- remote_num_tokens bounded by block capacity (num_computed_tokens)
- kv_recompute_threshold skips small transfers
- P-node holds blocks for D after finishing
- Cache MISS first turn falls back to local prefill
- Partial remote coverage: P pulls partial, computes the rest
- _remote_blocks_processed flag persists across reschedules
P-node flags: do_remote_prefill=False (prefill locally),
do_remote_decode=True (don't decode locally, send KV to D).
P pulls KV from D when remote_block_ids is not None and
external tokens > 0.
"""
import copy
import time
from unittest.mock import patch
import pytest
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorRole
from vllm.distributed.kv_transfer.kv_connector.v1.nixl.connector import (
NixlConnector,
NixlConnectorMetadata,
)
from vllm.forward_context import ForwardContext
from vllm.v1.outputs import (
EMPTY_MODEL_RUNNER_OUTPUT,
KVConnectorOutput,
)
from vllm.v1.request import RequestStatus
from .test_nixl_connector import FakeNixlConnectorWorker, FakeNixlWrapper
from .utils import (
assert_scheduler_empty,
create_model_runner_output,
create_request,
create_scheduler,
create_vllm_config,
make_kv_cache_config,
)
# Apply the ``cpu_test`` marker to every test in this module.
pytestmark = pytest.mark.cpu_test
# Common extra config for all bi-directional KV transfer tests.
BIDIR_KV_EXTRA_CONFIG = {"bidirectional_kv_xfer": True, "kv_recompute_threshold": 0}
# Helpers
def _make_p_node_turn2_request(
    request_id, block_size, num_tokens, num_remote_blocks=3, remote_num_tokens=None
):
    """Build a P-node second-turn request that points at blocks held by D.

    ``remote_num_tokens`` defaults to the full capacity of the remote
    blocks (``num_remote_blocks * block_size``).
    """
    if remote_num_tokens is None:
        remote_num_tokens = num_remote_blocks * block_size

    request = create_request(
        request_id=request_id,
        block_size=block_size,
        num_tokens=num_tokens,
        do_remote_decode=True,
    )
    # Fields describing where D keeps the previous turn's KV blocks.
    params = request.kv_transfer_params
    params["remote_block_ids"] = [list(range(num_remote_blocks))]
    params["remote_num_tokens"] = remote_num_tokens
    params["remote_engine_id"] = "decode-engine"
    params["remote_request_id"] = f"decode-{request_id}"
    params["remote_host"] = "decode-host"
    params["remote_port"] = 5678
    return request
def _make_connector_with_fake_worker(
    hand_shake_latency=0, cycles_before_done=0, do_handshake=True
):
    """Return ``(connector, worker)`` with the worker swapped for a fake.

    When ``do_handshake`` is true, the handshake with the fake remote
    engine is performed up front and its agents are registered.
    """
    config = create_vllm_config()
    cache_config = make_kv_cache_config(block_size=16, num_blocks=2)
    connector = NixlConnector(config, KVConnectorRole.WORKER, cache_config)
    connector.connector_worker = FakeNixlConnectorWorker(
        config,
        connector.engine_id,
        hand_shake_latency=hand_shake_latency,
        kv_cache_config=cache_config,
    )
    worker = connector.connector_worker

    wrapper = worker.nixl_wrapper
    assert isinstance(wrapper, FakeNixlWrapper)
    wrapper.set_cycles_before_xfer_done(cycles_before_done)
    worker.kv_cache_layout = "HND"

    if not do_handshake:
        return connector, worker

    agents = worker._nixl_handshake(
        host="localhost",
        port=1234,
        remote_tp_size=1,
        expected_engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
    )
    worker._remote_agents[FakeNixlConnectorWorker.REMOTE_ENGINE_ID] = agents
    return connector, worker
def _make_p_node_recv_metadata(request_id, local_blocks, remote_blocks):
    """Return connector metadata holding one receive entry for a P-side pull.

    The entry pairs ``local_blocks`` with ``remote_blocks`` located on the
    fake remote engine at localhost:1234.
    """
    pull_params = {
        "do_remote_prefill": False,
        "do_remote_decode": True,
        "remote_block_ids": (remote_blocks,),
        "remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
        "remote_request_id": f"decode-{request_id}",
        "remote_host": "localhost",
        "remote_port": 1234,
        "remote_tp_size": 1,
    }
    metadata = NixlConnectorMetadata()
    metadata.add_new_req_to_recv(
        request_id=request_id,
        local_block_ids=(local_blocks,),
        kv_transfer_params=pull_params,
    )
    return metadata
def _do_load_kv(connector, metadata):
    """Attach ``metadata`` to the connector, then kick off the KV load."""
    connector.bind_connector_metadata(metadata)
    # An empty forward context is enough for start_load_kv here.
    connector.start_load_kv(
        ForwardContext(no_compile_layers={}, attn_metadata={}, slot_mapping={})
    )
# 1. P-node scheduler lifecycle tests
def test_multiturn_lifecycle():
    """Two consecutive turns on the P node.

    Turn 1: P prefills locally and finishes LENGTH_CAPPED, reporting
    remote_block_ids in its kv_transfer_params.
    Turn 2: the request arrives with remote_block_ids from D, waits for
    the remote KV, then runs and finishes LENGTH_CAPPED.
    Both requests are released once their sends are reported finished.
    """
    config = create_vllm_config(
        kv_connector_extra_config=BIDIR_KV_EXTRA_CONFIG,
    )
    scheduler = create_scheduler(config)
    block_size = config.cache_config.block_size

    # Turn 1: plain local prefill.
    turn1 = create_request(
        request_id=100,
        block_size=block_size,
        num_tokens=int(block_size * 2.5),
        do_remote_decode=True,
    )
    scheduler.add_request(turn1)
    turn1_id = turn1.request_id

    step = scheduler.schedule()
    engine_outputs = scheduler.update_from_output(
        step, create_model_runner_output(reqs=[turn1])
    )
    assert turn1.status == RequestStatus.FINISHED_LENGTH_CAPPED
    kv = engine_outputs[0].outputs[0].kv_transfer_params
    assert kv and sum(len(g) for g in kv["remote_block_ids"]) > 0

    step = scheduler.schedule()
    scheduler.update_from_output(step, EMPTY_MODEL_RUNNER_OUTPUT)

    # Turn 2: request carries D's block ids.
    turn2 = _make_p_node_turn2_request(200, block_size, int(block_size * 2.5))
    scheduler.add_request(turn2)
    turn2_id = turn2.request_id

    step = scheduler.schedule()
    assert turn2.status == RequestStatus.WAITING_FOR_REMOTE_KVS
    scheduler.update_from_output(step, EMPTY_MODEL_RUNNER_OUTPUT)

    # Worker reports the pull from D as complete.
    step = scheduler.schedule()
    runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
    runner_output.kv_connector_output = KVConnectorOutput(
        finished_recving={turn2_id}
    )
    scheduler.update_from_output(step, runner_output)

    step = scheduler.schedule()
    scheduler.update_from_output(step, create_model_runner_output(reqs=[turn2]))
    assert turn2.status == RequestStatus.FINISHED_LENGTH_CAPPED

    step = scheduler.schedule()
    scheduler.update_from_output(step, EMPTY_MODEL_RUNNER_OUTPUT)

    # Both sends complete; nothing should remain in the scheduler.
    step = scheduler.schedule()
    runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
    runner_output.kv_connector_output = KVConnectorOutput(
        finished_sending={turn1_id, turn2_id}
    )
    scheduler.update_from_output(step, runner_output)
    assert_scheduler_empty(scheduler)
def test_first_turn_no_remote_blocks():
    """First turn without remote_block_ids: the request never waits for
    remote KV, finishes LENGTH_CAPPED and reports kv_transfer_params."""
    config = create_vllm_config(
        kv_connector_extra_config=BIDIR_KV_EXTRA_CONFIG,
    )
    scheduler = create_scheduler(config)
    block_size = config.cache_config.block_size

    request = create_request(
        request_id=3,
        block_size=block_size,
        num_tokens=int(block_size * 2.5),
        do_remote_decode=True,
    )
    scheduler.add_request(request)
    request_id = request.request_id

    step = scheduler.schedule()
    assert request.status != RequestStatus.WAITING_FOR_REMOTE_KVS
    engine_outputs = scheduler.update_from_output(
        step, create_model_runner_output(reqs=[request])
    )
    assert request.status == RequestStatus.FINISHED_LENGTH_CAPPED
    assert engine_outputs[0].outputs[0].kv_transfer_params is not None

    step = scheduler.schedule()
    scheduler.update_from_output(step, EMPTY_MODEL_RUNNER_OUTPUT)

    # Send completes; the scheduler should end up empty.
    step = scheduler.schedule()
    runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
    runner_output.kv_connector_output = KVConnectorOutput(
        finished_sending={request_id}
    )
    scheduler.update_from_output(step, runner_output)
    assert_scheduler_empty(scheduler)
def test_abort_p_side_during_send():
    """P side with do_remote_decode=True: the finished request stays
    tracked by the scheduler until finished_sending is reported."""
    config = create_vllm_config(
        kv_connector_extra_config=BIDIR_KV_EXTRA_CONFIG,
    )
    scheduler = create_scheduler(config)
    block_size = config.cache_config.block_size

    request = create_request(
        request_id=42,
        block_size=block_size,
        num_tokens=int(block_size * 2.5),
        do_remote_decode=True,
    )
    scheduler.add_request(request)
    request_id = request.request_id

    step = scheduler.schedule()
    scheduler.update_from_output(step, create_model_runner_output(reqs=[request]))
    assert request_id in scheduler.requests

    # An idle step must not release the request.
    step = scheduler.schedule()
    scheduler.update_from_output(step, EMPTY_MODEL_RUNNER_OUTPUT)
    assert request_id in scheduler.requests

    step = scheduler.schedule()
    runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
    runner_output.kv_connector_output = KVConnectorOutput(
        finished_sending={request_id}
    )
    scheduler.update_from_output(step, runner_output)
    assert_scheduler_empty(scheduler)
def test_abort_p_side_non_length_capped():
    """P-side abort of a request that did not finish LENGTH_CAPPED:
    the request is dropped from the scheduler immediately."""
    config = create_vllm_config(
        kv_connector_extra_config=BIDIR_KV_EXTRA_CONFIG,
    )
    scheduler = create_scheduler(config)
    block_size = config.cache_config.block_size

    request = create_request(
        request_id=44,
        block_size=block_size,
        num_tokens=int(block_size * 2.5),
        do_remote_decode=True,
    )
    # Raise the token budget so one step does not length-cap the request.
    request.sampling_params.max_tokens = 100
    request.max_tokens = 100
    scheduler.add_request(request)
    request_id = request.request_id

    step = scheduler.schedule()
    scheduler.update_from_output(step, create_model_runner_output(reqs=[request]))

    scheduler.finish_requests([request_id], RequestStatus.FINISHED_ABORTED)
    connector_scheduler = scheduler.connector.connector_scheduler
    assert request_id in connector_scheduler._reqs_not_processed
    assert request_id not in scheduler.requests

    step = scheduler.schedule()
    scheduler.update_from_output(step, EMPTY_MODEL_RUNNER_OUTPUT)
    assert_scheduler_empty(scheduler)
def test_remote_blocks_exceed_prompt_tokens():
    """D advertises more remote tokens than the prompt has: after
    scheduling, num_computed_tokens equals the prompt length."""
    config = create_vllm_config(
        kv_connector_extra_config=BIDIR_KV_EXTRA_CONFIG,
    )
    scheduler = create_scheduler(config)
    block_size = config.cache_config.block_size
    prompt_tokens = int(block_size * 2.5)

    request = _make_p_node_turn2_request(
        300,
        block_size,
        prompt_tokens,
        num_remote_blocks=5,
        remote_num_tokens=5 * block_size,
    )
    scheduler.add_request(request)
    request_id = request.request_id

    step = scheduler.schedule()
    assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
    assert request.num_computed_tokens == prompt_tokens
    scheduler.update_from_output(step, EMPTY_MODEL_RUNNER_OUTPUT)

    # Pull from D completes.
    step = scheduler.schedule()
    runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
    runner_output.kv_connector_output = KVConnectorOutput(
        finished_recving={request_id}
    )
    scheduler.update_from_output(step, runner_output)

    step = scheduler.schedule()
    scheduler.update_from_output(step, create_model_runner_output(reqs=[request]))
    assert request.status == RequestStatus.FINISHED_LENGTH_CAPPED

    step = scheduler.schedule()
    scheduler.update_from_output(step, EMPTY_MODEL_RUNNER_OUTPUT)

    # Send back to D completes.
    step = scheduler.schedule()
    runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
    runner_output.kv_connector_output = KVConnectorOutput(
        finished_sending={request_id}
    )
    scheduler.update_from_output(step, runner_output)
    assert_scheduler_empty(scheduler)
def test_p_node_pulls_partial_last_block_from_d():
    """D's last remote block is only partly filled, i.e.
    remote_num_tokens < number of remote blocks * block_size.
    The request still goes through wait → run → finish → release."""
    config = create_vllm_config(
        kv_connector_extra_config=BIDIR_KV_EXTRA_CONFIG,
    )
    scheduler = create_scheduler(config)
    block_size = config.cache_config.block_size

    remote_block_count = 3
    remote_tokens = int(block_size * 2.5)
    assert remote_tokens < remote_block_count * block_size
    prompt_tokens = int(block_size * 3.5)

    request = _make_p_node_turn2_request(
        400,
        block_size,
        prompt_tokens,
        num_remote_blocks=remote_block_count,
        remote_num_tokens=remote_tokens,
    )
    scheduler.add_request(request)
    request_id = request.request_id

    step = scheduler.schedule()
    assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
    scheduler.update_from_output(step, EMPTY_MODEL_RUNNER_OUTPUT)

    # Pull from D completes.
    step = scheduler.schedule()
    runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
    runner_output.kv_connector_output = KVConnectorOutput(
        finished_recving={request_id}
    )
    scheduler.update_from_output(step, runner_output)

    step = scheduler.schedule()
    assert len(scheduler.running) == 1
    scheduler.update_from_output(step, create_model_runner_output(reqs=[request]))
    assert request.status == RequestStatus.FINISHED_LENGTH_CAPPED

    step = scheduler.schedule()
    scheduler.update_from_output(step, EMPTY_MODEL_RUNNER_OUTPUT)

    # Send back to D completes.
    step = scheduler.schedule()
    runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
    runner_output.kv_connector_output = KVConnectorOutput(
        finished_sending={request_id}
    )
    scheduler.update_from_output(step, runner_output)
    assert_scheduler_empty(scheduler)
# 2. P-node metadata tests
def test_add_new_req_to_recv_populates_remote_meta():
    """add_new_req_to_recv copies the remote fields of kv_transfer_params
    into the receive entry and keeps the local block ids."""
    remote_params = {
        "remote_block_ids": [[0, 1, 2]],
        "remote_engine_id": "decode-engine",
        "remote_request_id": "decode-req-123",
        "remote_host": "decode-host",
        "remote_port": 5678,
    }
    local_ids = ([10, 11, 12],)

    metadata = NixlConnectorMetadata()
    metadata.add_new_req_to_recv(
        request_id="test-req",
        local_block_ids=local_ids,
        kv_transfer_params=remote_params,
    )

    assert "test-req" in metadata.reqs_to_recv
    entry = metadata.reqs_to_recv["test-req"]
    remote = entry.remote
    assert remote is not None
    assert remote.block_ids == remote_params["remote_block_ids"]
    assert remote.engine_id == "decode-engine"
    assert remote.request_id == "decode-req-123"
    assert remote.host == "decode-host"
    assert remote.port == 5678
    assert entry.local_block_ids == local_ids
def test_build_connector_meta_recv_entries():
    """A P-node request with do_remote_decode=True and remote_block_ids
    shows up in reqs_to_recv of the scheduler output's connector metadata."""
    config = create_vllm_config(
        kv_connector_extra_config=BIDIR_KV_EXTRA_CONFIG,
    )
    scheduler = create_scheduler(config)
    block_size = config.cache_config.block_size

    request = _make_p_node_turn2_request(1, block_size, int(block_size * 2.5))
    scheduler.add_request(request)
    request_id = request.request_id

    step = scheduler.schedule()
    assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS

    metadata = step.kv_connector_metadata
    assert isinstance(metadata, NixlConnectorMetadata)
    assert request_id in metadata.reqs_to_recv
    entry = metadata.reqs_to_recv[request_id]
    assert entry.remote is not None
    assert entry.remote.engine_id == "decode-engine"
def test_build_connector_meta_clears_reqs_need_recv():
    """Once a schedule step has run, _reqs_need_recv is empty again."""
    config = create_vllm_config(
        kv_connector_extra_config=BIDIR_KV_EXTRA_CONFIG,
    )
    scheduler = create_scheduler(config)
    block_size = config.cache_config.block_size

    scheduler.add_request(
        _make_p_node_turn2_request(2, block_size, int(block_size * 2.5))
    )
    connector_scheduler = scheduler.connector.connector_scheduler

    scheduler.schedule()
    assert len(connector_scheduler._reqs_need_recv) == 0
def test_build_connector_meta_multiple_requests():
    """Three P-node requests scheduled together all land in reqs_to_recv."""
    config = create_vllm_config(
        kv_connector_extra_config=BIDIR_KV_EXTRA_CONFIG,
    )
    scheduler = create_scheduler(config)
    block_size = config.cache_config.block_size

    requests = [
        _make_p_node_turn2_request(10 + offset, block_size, int(block_size * 2.5))
        for offset in range(3)
    ]
    for request in requests:
        scheduler.add_request(request)

    metadata = scheduler.schedule().kv_connector_metadata
    assert isinstance(metadata, NixlConnectorMetadata)
    assert len(metadata.reqs_to_recv) == 3
    for request in requests:
        assert request.request_id in metadata.reqs_to_recv
# 3. P-node worker tests (FakeNixlWrapper)
@patch(
    "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.NixlWrapper",
    FakeNixlWrapper,
)
def test_p_node_pull_kv_from_d(dist_init):
    """start_load_kv with a reqs_to_recv entry registers the request as
    receiving, and get_finished then reports it as done receiving."""
    connector, worker = _make_connector_with_fake_worker()
    _do_load_kv(
        connector,
        _make_p_node_recv_metadata("req-p1", [10, 11, 12], [20, 21, 22]),
    )
    assert "req-p1" in worker._recving_metadata

    finished = connector.get_finished(finished_req_ids=set())
    received = finished[1]
    assert "req-p1" in received
@patch(
    "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.NixlWrapper",
    FakeNixlWrapper,
)
def test_p_node_pull_then_send_kv(dist_init):
    """Both directions on the P worker: the pull from D completes, then a
    notification for the same request marks the send back to D as done."""
    connector, worker = _make_connector_with_fake_worker()

    # Direction 1: pull from D.
    _do_load_kv(connector, _make_p_node_recv_metadata("req-p2", [10, 11], [20, 21]))
    _, received = connector.get_finished(finished_req_ids=set())
    assert "req-p2" in received

    # Direction 2: register a pending send, then feed in a notification
    # for it through a temporarily replaced get_new_notifs.
    worker._reqs_to_send["req-p2"] = time.perf_counter() + 60
    worker._reqs_to_process.add("req-p2")
    notification = f"req-p2:{worker.world_size}".encode()
    original_get_new_notifs = worker.nixl_wrapper.get_new_notifs
    worker.nixl_wrapper.get_new_notifs = lambda: {"agent": [notification]}

    sent, _ = connector.get_finished(finished_req_ids=set())
    assert "req-p2" in sent
    worker.nixl_wrapper.get_new_notifs = original_get_new_notifs
@patch(
    "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.NixlWrapper",
    FakeNixlWrapper,
)
def test_p_node_deferred_pull_on_no_handshake(dist_init):
    """Without a prior handshake the pull does not finish right away; the
    test polls for up to 3 seconds until the request is reported done."""
    connector, worker = _make_connector_with_fake_worker(
        hand_shake_latency=0, do_handshake=False
    )
    _do_load_kv(connector, _make_p_node_recv_metadata("req-p3", [10, 11], [20, 21]))
    assert "req-p3" in worker._recving_metadata

    timeout = 3.0
    began = time.perf_counter()
    while time.perf_counter() - began < timeout:
        # Drive the worker with empty metadata so it can make progress.
        _do_load_kv(connector, NixlConnectorMetadata())
        _, received = connector.get_finished(finished_req_ids=set())
        if "req-p3" in received:
            return
        time.sleep(0.2)
    raise AssertionError("Transfer did not complete after async handshake")
# 4. D-node request_finished returns kv_transfer_params (new behavior)
def test_d_node_request_finished_returns_kv_params():
    """Check the kv_transfer_params emitted when a remote-prefill request stops.

    With the bidirectional config, the finished request's output must carry
    params with ``do_remote_decode=True``, ``do_remote_prefill=False``,
    ``remote_block_ids`` and a positive ``remote_num_tokens``.
    """
    vllm_config = create_vllm_config(
        kv_connector_extra_config=BIDIR_KV_EXTRA_CONFIG,
    )
    scheduler = create_scheduler(vllm_config)
    block_size = vllm_config.cache_config.block_size
    req = create_request(
        request_id=1,
        block_size=block_size,
        num_tokens=int(block_size * 2.5),
        do_remote_prefill=True,
    )
    scheduler.add_request(req)
    req_id = req.request_id

    # Step 1: report that the KV receive for this request has finished.
    sched_out = scheduler.schedule()
    recv_done = create_model_runner_output(reqs=[], finished_recving={req_id})
    scheduler.update_from_output(sched_out, recv_done)

    # Step 2: run the request and end it with an EOS token.
    sched_out = scheduler.schedule()
    eos_output = create_model_runner_output(reqs=[req], use_eos=True)
    engine_outputs = scheduler.update_from_output(sched_out, eos_output)
    assert req.status == RequestStatus.FINISHED_STOPPED

    kv = engine_outputs[0].outputs[0].kv_transfer_params
    assert kv is not None
    assert kv["do_remote_decode"] is True
    assert kv["do_remote_prefill"] is False
    for required_key in ("remote_block_ids", "remote_num_tokens"):
        assert required_key in kv
    assert kv["remote_num_tokens"] > 0
def test_d_node_request_finished_delays_block_free():
    """Check that a finished remote-prefill request is kept around for sending.

    After the request ends with EOS under the bidirectional config, it must
    still be present in ``scheduler.requests`` and be listed in the connector
    scheduler's ``_reqs_need_send``.
    """
    vllm_config = create_vllm_config(
        kv_connector_extra_config=BIDIR_KV_EXTRA_CONFIG,
    )
    scheduler = create_scheduler(vllm_config)
    block_size = vllm_config.cache_config.block_size
    req = create_request(
        request_id=2,
        block_size=block_size,
        num_tokens=int(block_size * 2.5),
        do_remote_prefill=True,
    )
    scheduler.add_request(req)
    req_id = req.request_id

    # Step 1: report that the KV receive for this request has finished.
    sched_out = scheduler.schedule()
    recv_done = create_model_runner_output(reqs=[], finished_recving={req_id})
    scheduler.update_from_output(sched_out, recv_done)

    # Step 2: run the request and end it with an EOS token.
    sched_out = scheduler.schedule()
    eos_output = create_model_runner_output(reqs=[req], use_eos=True)
    scheduler.update_from_output(sched_out, eos_output)

    assert req_id in scheduler.requests
    connector_scheduler = scheduler.connector.connector_scheduler
    assert req_id in connector_scheduler._reqs_need_send
def test_d_node_request_finished_remote_num_tokens():
    """Check that the finished request reports tokens and blocks to pull.

    The emitted kv_transfer_params must have a positive ``remote_num_tokens``
    and at least one block id across all groups of ``remote_block_ids``.
    """
    vllm_config = create_vllm_config(
        kv_connector_extra_config=BIDIR_KV_EXTRA_CONFIG,
    )
    scheduler = create_scheduler(vllm_config)
    block_size = vllm_config.cache_config.block_size
    req = create_request(
        request_id=3,
        block_size=block_size,
        num_tokens=int(block_size * 2.5),
        do_remote_prefill=True,
    )
    scheduler.add_request(req)
    req_id = req.request_id

    # Step 1: report that the KV receive for this request has finished.
    sched_out = scheduler.schedule()
    recv_done = create_model_runner_output(reqs=[], finished_recving={req_id})
    scheduler.update_from_output(sched_out, recv_done)

    # Step 2: run the request and end it with an EOS token.
    sched_out = scheduler.schedule()
    eos_output = create_model_runner_output(reqs=[req], use_eos=True)
    engine_outputs = scheduler.update_from_output(sched_out, eos_output)

    kv = engine_outputs[0].outputs[0].kv_transfer_params
    assert kv["remote_num_tokens"] > 0
    block_count = sum(len(group) for group in kv["remote_block_ids"])
    assert block_count > 0
def test_d_node_partial_last_block_remote_num_tokens():
    """Check token accounting when the last block is only partly filled.

    The request is 2.5 blocks long, so three blocks are expected in
    ``remote_block_ids`` while ``remote_num_tokens`` must stay positive and
    strictly below the capacity of those three blocks.
    """
    vllm_config = create_vllm_config(
        kv_connector_extra_config=BIDIR_KV_EXTRA_CONFIG,
    )
    scheduler = create_scheduler(vllm_config)
    block_size = vllm_config.cache_config.block_size
    req = create_request(
        request_id=5,
        block_size=block_size,
        num_tokens=int(block_size * 2.5),
        do_remote_prefill=True,
    )
    scheduler.add_request(req)
    req_id = req.request_id

    # Step 1: report that the KV receive for this request has finished.
    sched_out = scheduler.schedule()
    recv_done = create_model_runner_output(reqs=[], finished_recving={req_id})
    scheduler.update_from_output(sched_out, recv_done)

    # Step 2: run the request and end it with an EOS token.
    sched_out = scheduler.schedule()
    eos_output = create_model_runner_output(reqs=[req], use_eos=True)
    engine_outputs = scheduler.update_from_output(sched_out, eos_output)

    kv = engine_outputs[0].outputs[0].kv_transfer_params
    total_blocks = sum(len(group) for group in kv["remote_block_ids"])
    assert total_blocks == 3
    block_capacity = total_blocks * block_size
    assert kv["remote_num_tokens"] < block_capacity
    assert kv["remote_num_tokens"] > 0
# 5. Edge case tests
def test_no_double_read_blocks_after_reschedule():
    """Edge case 1: a bidirectional request is scheduled for receive only once.

    The request is scheduled, waits for remote KVs, has its receive reported
    as finished, and is scheduled again. The metadata built by that later
    schedule must not list the request in ``reqs_to_recv`` a second time.
    """
    vllm_config = create_vllm_config(
        kv_connector_extra_config=BIDIR_KV_EXTRA_CONFIG,
    )
    scheduler = create_scheduler(vllm_config)
    block_size = vllm_config.cache_config.block_size
    req = _make_p_node_turn2_request(500, block_size, int(block_size * 2.5))
    scheduler.add_request(req)
    req_id = req.request_id
    connector_scheduler = scheduler.connector.connector_scheduler

    # Initial schedule: the request waits for remote KVs and shows up in the
    # receive list of the metadata handed out for this step.
    sched_out = scheduler.schedule()
    assert req.status == RequestStatus.WAITING_FOR_REMOTE_KVS
    first_meta = sched_out.kv_connector_metadata
    assert isinstance(first_meta, NixlConnectorMetadata)
    assert req_id in first_meta.reqs_to_recv
    # The pending-receive map must be empty once the metadata has been built.
    assert not connector_scheduler._reqs_need_recv

    # An empty step, then a step whose output reports the receive as finished.
    scheduler.update_from_output(sched_out, EMPTY_MODEL_RUNNER_OUTPUT)
    sched_out = scheduler.schedule()
    recv_done = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
    recv_done.kv_connector_output = KVConnectorOutput(finished_recving={req_id})
    scheduler.update_from_output(sched_out, recv_done)

    # Schedule after the receive: the request must not be queued to receive
    # again.
    sched_out = scheduler.schedule()
    second_meta = sched_out.kv_connector_metadata
    assert isinstance(second_meta, NixlConnectorMetadata)
    assert req_id not in second_meta.reqs_to_recv

    # Clean up: finish the request, then report its send as finished so the
    # scheduler ends up empty.
    run_output = create_model_runner_output(reqs=[req])
    scheduler.update_from_output(sched_out, run_output)
    assert req.status == RequestStatus.FINISHED_LENGTH_CAPPED
    sched_out = scheduler.schedule()
    scheduler.update_from_output(sched_out, EMPTY_MODEL_RUNNER_OUTPUT)
    sched_out = scheduler.schedule()
    send_done = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
    send_done.kv_connector_output = KVConnectorOutput(finished_sending={req_id})
    scheduler.update_from_output(sched_out, send_done)
    assert_scheduler_empty(scheduler)
def test_remote_num_tokens_bounded_by_blocks():
    """Edge case 2: ``remote_num_tokens`` never exceeds the advertised blocks.

    For a finished remote-prefill request, the reported token count must be
    positive and at most ``number of remote blocks * block_size``.
    """
    vllm_config = create_vllm_config(
        kv_connector_extra_config=BIDIR_KV_EXTRA_CONFIG,
    )
    scheduler = create_scheduler(vllm_config)
    block_size = vllm_config.cache_config.block_size
    req = create_request(
        request_id=501,
        block_size=block_size,
        num_tokens=int(block_size * 2.5),
        do_remote_prefill=True,
    )
    scheduler.add_request(req)
    req_id = req.request_id

    # Step 1: report that the KV receive for this request has finished.
    sched_out = scheduler.schedule()
    recv_done = create_model_runner_output(reqs=[], finished_recving={req_id})
    scheduler.update_from_output(sched_out, recv_done)

    # Step 2: run the request and end it with an EOS token.
    sched_out = scheduler.schedule()
    eos_output = create_model_runner_output(reqs=[req], use_eos=True)
    engine_outputs = scheduler.update_from_output(sched_out, eos_output)

    kv = engine_outputs[0].outputs[0].kv_transfer_params
    assert kv is not None
    total_blocks = sum(len(group) for group in kv["remote_block_ids"])
    max_tokens_in_blocks = total_blocks * block_size
    assert kv["remote_num_tokens"] <= max_tokens_in_blocks, (
        f"remote_num_tokens ({kv['remote_num_tokens']}) exceeds "
        f"block capacity ({max_tokens_in_blocks})"
    )
    assert kv["remote_num_tokens"] > 0
def test_kv_recompute_threshold_skips_small_transfer():
    """Edge case 3: a small remote offer is not pulled.

    With ``kv_recompute_threshold`` set to 256 and a request that advertises
    three remote blocks, the request must go straight to RUNNING instead of
    WAITING_FOR_REMOTE_KVS.
    """
    threshold = 256
    extra_config = {
        "bidirectional_kv_xfer": True,
        "kv_recompute_threshold": threshold,
    }
    vllm_config = create_vllm_config(kv_connector_extra_config=extra_config)
    scheduler = create_scheduler(vllm_config)
    block_size = vllm_config.cache_config.block_size

    # The remote side offers 3 blocks (3 * block_size tokens); the test relies
    # on that being below the threshold.
    req = _make_p_node_turn2_request(
        502,
        block_size,
        int(block_size * 2.5),
        num_remote_blocks=3,
        remote_num_tokens=3 * block_size,
    )
    scheduler.add_request(req)
    sched_out = scheduler.schedule()

    # No wait for remote KVs: the request runs immediately.
    assert req.status != RequestStatus.WAITING_FOR_REMOTE_KVS
    assert req.status == RequestStatus.RUNNING

    # Clean up: finish the request, then report its send as finished so the
    # scheduler ends up empty.
    run_output = create_model_runner_output(reqs=[req])
    scheduler.update_from_output(sched_out, run_output)
    assert req.status == RequestStatus.FINISHED_LENGTH_CAPPED
    sched_out = scheduler.schedule()
    scheduler.update_from_output(sched_out, EMPTY_MODEL_RUNNER_OUTPUT)
    sched_out = scheduler.schedule()
    send_done = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
    send_done.kv_connector_output = KVConnectorOutput(
        finished_sending={req.request_id}
    )
    scheduler.update_from_output(sched_out, send_done)
    assert_scheduler_empty(scheduler)
def test_p_node_finished_holds_blocks_for_d():
    """Edge case 4: a finished remote-decode request advertises its blocks.

    A ``do_remote_decode=True`` request that ends as FINISHED_LENGTH_CAPPED
    must return kv_transfer_params with ``do_remote_prefill=True``,
    ``do_remote_decode=False`` and a non-empty ``remote_block_ids``, and must
    stay in ``scheduler.requests`` until its send is reported as finished.
    """
    vllm_config = create_vllm_config(
        kv_connector_extra_config=BIDIR_KV_EXTRA_CONFIG,
    )
    scheduler = create_scheduler(vllm_config)
    block_size = vllm_config.cache_config.block_size
    req = create_request(
        request_id=503,
        block_size=block_size,
        num_tokens=int(block_size * 2.5),
        do_remote_decode=True,
    )
    scheduler.add_request(req)
    req_id = req.request_id

    # A single step is enough to finish the request.
    sched_out = scheduler.schedule()
    run_output = create_model_runner_output(reqs=[req])
    engine_outputs = scheduler.update_from_output(sched_out, run_output)
    assert req.status == RequestStatus.FINISHED_LENGTH_CAPPED

    kv = engine_outputs[0].outputs[0].kv_transfer_params
    assert kv is not None
    # The params tell the receiving side to pull (do_remote_prefill=True).
    assert kv["do_remote_prefill"] is True
    assert kv["do_remote_decode"] is False
    assert "remote_block_ids" in kv
    block_count = sum(len(group) for group in kv["remote_block_ids"])
    assert block_count > 0
    # The request is still tracked by the scheduler at this point.
    assert req_id in scheduler.requests

    # Clean up: report the send as finished so the scheduler ends up empty.
    sched_out = scheduler.schedule()
    scheduler.update_from_output(sched_out, EMPTY_MODEL_RUNNER_OUTPUT)
    sched_out = scheduler.schedule()
    send_done = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
    send_done.kv_connector_output = KVConnectorOutput(finished_sending={req_id})
    scheduler.update_from_output(sched_out, send_done)
    assert_scheduler_empty(scheduler)
def test_cache_miss_first_turn_no_remote_pull():
    """Edge case 5: no remote blocks on the request means no remote wait.

    A ``do_remote_decode=True`` request whose kv_transfer_params carry no
    ``remote_block_ids`` must be RUNNING after the first schedule rather than
    WAITING_FOR_REMOTE_KVS.
    """
    vllm_config = create_vllm_config(
        kv_connector_extra_config=BIDIR_KV_EXTRA_CONFIG,
    )
    scheduler = create_scheduler(vllm_config)
    block_size = vllm_config.cache_config.block_size
    req = create_request(
        request_id=504,
        block_size=block_size,
        num_tokens=int(block_size * 2.5),
        do_remote_decode=True,
    )
    # Precondition: the request was built without any remote block ids.
    assert req.kv_transfer_params.get("remote_block_ids") is None

    scheduler.add_request(req)
    sched_out = scheduler.schedule()
    # The request runs immediately instead of waiting on a remote transfer.
    assert req.status != RequestStatus.WAITING_FOR_REMOTE_KVS
    assert req.status == RequestStatus.RUNNING

    # Clean up: run the request, then report its send as finished so the
    # scheduler ends up empty.
    run_output = create_model_runner_output(reqs=[req])
    scheduler.update_from_output(sched_out, run_output)
    sched_out = scheduler.schedule()
    scheduler.update_from_output(sched_out, EMPTY_MODEL_RUNNER_OUTPUT)
    sched_out = scheduler.schedule()
    send_done = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
    send_done.kv_connector_output = KVConnectorOutput(
        finished_sending={req.request_id}
    )
    scheduler.update_from_output(sched_out, send_done)
    assert_scheduler_empty(scheduler)
def test_partial_remote_tokens_less_than_prompt():
    """Edge case 6: the remote offer covers only part of the prompt.

    The prompt is 4.5 blocks long while only two remote blocks are offered.
    The request must wait for remote KVs with ``num_computed_tokens`` below
    the prompt length, then run to FINISHED_LENGTH_CAPPED and be cleaned up.
    """
    vllm_config = create_vllm_config(
        kv_connector_extra_config=BIDIR_KV_EXTRA_CONFIG,
    )
    scheduler = create_scheduler(vllm_config)
    block_size = vllm_config.cache_config.block_size
    prompt_len = int(block_size * 4.5)

    # Only 2 blocks' worth of tokens are advertised as remotely available.
    req = _make_p_node_turn2_request(
        505,
        block_size,
        prompt_len,
        num_remote_blocks=2,
        remote_num_tokens=2 * block_size,
    )
    scheduler.add_request(req)
    req_id = req.request_id

    sched_out = scheduler.schedule()
    assert req.status == RequestStatus.WAITING_FOR_REMOTE_KVS
    # Fewer tokens than the full prompt are counted as computed.
    assert req.num_computed_tokens < prompt_len

    # An empty step, then a step whose output reports the receive as finished.
    scheduler.update_from_output(sched_out, EMPTY_MODEL_RUNNER_OUTPUT)
    sched_out = scheduler.schedule()
    recv_done = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
    recv_done.kv_connector_output = KVConnectorOutput(finished_recving={req_id})
    scheduler.update_from_output(sched_out, recv_done)

    # Run the request to completion.
    sched_out = scheduler.schedule()
    run_output = create_model_runner_output(reqs=[req])
    scheduler.update_from_output(sched_out, run_output)
    assert req.status == RequestStatus.FINISHED_LENGTH_CAPPED

    # Clean up: report the send as finished so the scheduler ends up empty.
    sched_out = scheduler.schedule()
    scheduler.update_from_output(sched_out, EMPTY_MODEL_RUNNER_OUTPUT)
    sched_out = scheduler.schedule()
    send_done = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
    send_done.kv_connector_output = KVConnectorOutput(finished_sending={req_id})
    scheduler.update_from_output(sched_out, send_done)
    assert_scheduler_empty(scheduler)
def test_remote_blocks_processed_flag_persists():
    """Edge case 7: ``_remote_blocks_processed`` blocks a second receive.

    After the receive is reported as finished, the request's
    kv_transfer_params must carry ``_remote_blocks_processed=True``. The next
    schedule must leave the request out of both the connector scheduler's
    ``_reqs_need_recv`` and the metadata's ``reqs_to_recv``.
    """
    vllm_config = create_vllm_config(
        kv_connector_extra_config=BIDIR_KV_EXTRA_CONFIG,
    )
    scheduler = create_scheduler(vllm_config)
    block_size = vllm_config.cache_config.block_size
    req = _make_p_node_turn2_request(506, block_size, int(block_size * 2.5))
    scheduler.add_request(req)
    req_id = req.request_id
    connector_scheduler = scheduler.connector.connector_scheduler

    # Initial schedule puts the request into WAITING_FOR_REMOTE_KVS.
    sched_out = scheduler.schedule()
    assert req.status == RequestStatus.WAITING_FOR_REMOTE_KVS
    scheduler.update_from_output(sched_out, EMPTY_MODEL_RUNNER_OUTPUT)

    # Report the receive as finished.
    sched_out = scheduler.schedule()
    recv_done = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
    recv_done.kv_connector_output = KVConnectorOutput(finished_recving={req_id})
    scheduler.update_from_output(sched_out, recv_done)

    # The marker must now be present on the request's params.
    assert req.kv_transfer_params.get("_remote_blocks_processed") is True

    # The following schedule must not queue the request for receive again.
    sched_out = scheduler.schedule()
    assert req_id not in connector_scheduler._reqs_need_recv
    meta = sched_out.kv_connector_metadata
    assert isinstance(meta, NixlConnectorMetadata)
    assert req_id not in meta.reqs_to_recv

    # Clean up: run the request, then report its send as finished so the
    # scheduler ends up empty.
    run_output = create_model_runner_output(reqs=[req])
    scheduler.update_from_output(sched_out, run_output)
    sched_out = scheduler.schedule()
    scheduler.update_from_output(sched_out, EMPTY_MODEL_RUNNER_OUTPUT)
    sched_out = scheduler.schedule()
    send_done = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
    send_done.kv_connector_output = KVConnectorOutput(finished_sending={req_id})
    scheduler.update_from_output(sched_out, send_done)
    assert_scheduler_empty(scheduler)
@@ -119,6 +119,30 @@ class NixlConnectorScheduler:
for n_tokens, block_size in sw_sizes_tokens
]
# Threshold to decide whether to compute kv cache locally
# or pull from a remote node: minimum number of remote
# tokens to amortize the xfer latencies
self.kv_recompute_threshold: int = int(
vllm_config.kv_transfer_config.get_from_extra_config(
"kv_recompute_threshold", 64
)
)
# Bi-directional KV transfer feature supports KV block
# transfers from D node to P node
self.is_bidirectional_kv_xfer_enabled = (
vllm_config.kv_transfer_config.get_from_extra_config(
"bidirectional_kv_xfer", False
)
)
if self.is_bidirectional_kv_xfer_enabled and self.kv_recompute_threshold > 0:
logger.info(
"Bidirectional KV transfer is enabled and the kv "
"recompute threshold is set to %d tokens",
self.kv_recompute_threshold,
)
def shutdown(self):
self._stop_event.set()
if self._nixl_handshake_listener_t is not None:
@@ -298,6 +322,44 @@ class NixlConnectorScheduler:
if params is not None and params.get("do_remote_decode") and self._has_mamba:
self._truncate_mamba_request_for_prefill(request)
if (
params is not None
and params.get("do_remote_decode")
and params.get("remote_block_ids")
and all(
p in params
for p in (
"remote_engine_id",
"remote_request_id",
"remote_host",
"remote_port",
)
)
):
# Decode node has kv blocks for part of prefill request, so, provide them
# as an external token count to scheduler.
# The tokens will be loaded if not already present
# in the prefill node local cache
remote_num_tokens = params.get("remote_num_tokens") or 0
count = (
min(remote_num_tokens, request.num_prompt_tokens) - num_computed_tokens
)
if count > 0:
# Check kv_recompute_threshold: skip pull if
# remote tokens are below the threshold.
if (
self.kv_recompute_threshold > 0
and count < self.kv_recompute_threshold
):
logger.debug(
"Skipping remote pull for %s: %d remote tokens < threshold %d",
request.request_id,
count,
self.kv_recompute_threshold,
)
return 0, False
return count, True
# No remote prefill for this request.
return 0, False
@@ -315,13 +377,19 @@ class NixlConnectorScheduler:
if not params:
return
if params.get("do_remote_decode"):
if params.get("do_remote_decode") or (
params.get("do_remote_prefill") and self.is_bidirectional_kv_xfer_enabled
):
self._reqs_in_batch.add(request.request_id)
if self.use_host_buffer and params.get("do_remote_decode"):
# NOTE: when accelerator is not directly supported by Nixl,
# prefilled blocks need to be saved to host memory before transfer.
self._reqs_need_save[request.request_id] = request
elif params.get("do_remote_prefill"):
elif params.get("do_remote_prefill") or (
params.get("do_remote_decode")
and self.is_bidirectional_kv_xfer_enabled
and not params.get("_remote_blocks_processed")
):
if params.get("remote_block_ids"):
if all(
p in params
@@ -333,8 +401,8 @@ class NixlConnectorScheduler:
)
):
# If remote_blocks and num_external_tokens = 0, we have
# a full prefix cache hit on the D worker. We need to call
# send_notif in _read_blocks to free the memory on the P.
# a full prefix cache hit on the local node. We need to call
# send_notif in _read_blocks to free the memory on the remote node.
unhashed_local_block_ids: BlockIds = (
blocks.get_unhashed_block_ids_all_groups()
@@ -362,6 +430,7 @@ class NixlConnectorScheduler:
assert num_external_tokens == 0
# Only trigger 1 KV transfer per request.
params["do_remote_prefill"] = False
params["_remote_blocks_processed"] = True
def _build_save_meta(
self,
@@ -450,6 +519,9 @@ class NixlConnectorScheduler:
if not params:
return False, None
is_p_node = bool(params.get("do_remote_decode"))
is_d_node = not is_p_node
if params.get("do_remote_prefill"):
# If do_remote_prefill is still True when the request is finished,
# update_state_after_alloc must not have been called (the request
@@ -461,9 +533,13 @@ class NixlConnectorScheduler:
params["do_remote_prefill"] = False
return False, None
if not params.get("do_remote_decode"):
if is_d_node and not self.is_bidirectional_kv_xfer_enabled:
return False, None
if request.status != RequestStatus.FINISHED_LENGTH_CAPPED:
if request.status not in (
RequestStatus.FINISHED_LENGTH_CAPPED,
RequestStatus.FINISHED_STOPPED,
):
# Also include the case of a P/D Prefill request with immediate
# block free (eg abort). Stop tracking this request.
self._reqs_not_processed.add(request.request_id)
@@ -474,7 +550,7 @@ class NixlConnectorScheduler:
# TODO: check whether block_ids actually ever be 0. If not we could
# remove the conditional below
delay_free_blocks = any(len(group) > 0 for group in block_ids)
remote_num_tokens = 0
if delay_free_blocks:
# Prefill request on remote. It will be read from D upon completion
logger.debug(
@@ -492,13 +568,16 @@ class NixlConnectorScheduler:
# Here we "unpad" blocks to send the actual remote blocks to be read.
block_ids = self.get_sw_clipped_blocks(block_ids)
remote_num_tokens = request.num_computed_tokens
return delay_free_blocks, dict(
do_remote_prefill=True,
do_remote_decode=False,
do_remote_prefill=is_p_node,
do_remote_decode=is_d_node,
remote_block_ids=block_ids,
remote_engine_id=self.engine_id,
remote_request_id=request.request_id,
remote_host=self.side_channel_host,
remote_port=self.side_channel_port,
tp_size=self.vllm_config.parallel_config.tensor_parallel_size,
remote_num_tokens=remote_num_tokens,
)