[Bugfix][KV Transfer][NIXL] Notify P node on pre-admission rejection to free stranded KV blocks (#41269)

This commit is contained in:
Dao007forever
2026-05-09 22:52:09 -07:00
committed by GitHub
parent fb1ac806c5
commit 3f5bd482f5
11 changed files with 156 additions and 2 deletions
@@ -321,6 +321,34 @@ def test_prompt_less_than_block_size():
assert len(scheduler_output.scheduled_new_reqs) == 0
def test_abort_immediately_remote_prefill_enqueues_empty_recv():
    """Aborting an unscheduled remote-prefill request enqueues an empty recv.

    The request is added to the scheduler and aborted before any scheduling
    step, which is what the abort_immediately path does. The next scheduler
    output must then carry NIXL connector metadata with a receive entry for
    that request that has no local block ids, so the prefill instance can be
    notified to free its blocks.
    """
    from vllm.v1.request import RequestStatus

    sched = create_scheduler(create_vllm_config())
    req = create_request(request_id=42, num_tokens=10, do_remote_prefill=True)

    # Precondition: the request is flagged for remote prefill.
    assert req.kv_transfer_params is not None
    assert req.kv_transfer_params["do_remote_prefill"] is True

    # Reproduce what EngineCore.add_request does for an abort-immediately
    # request: enqueue it, then abort it before it is ever scheduled.
    sched.add_request(req)
    sched.finish_requests([req.request_id], RequestStatus.FINISHED_ABORTED)

    connector_meta = sched.schedule().kv_connector_metadata
    assert isinstance(connector_meta, NixlConnectorMetadata)

    # Exactly one pending receive, keyed by the aborted request.
    assert set(connector_meta.reqs_to_recv) == {req.request_id}
    recv_entry = connector_meta.reqs_to_recv[req.request_id]
    assert recv_entry.local_block_ids == []
    assert recv_entry.remote.request_id == f"prefill-{42}"

    # request_finished clears do_remote_prefill so the recv is not re-issued.
    assert req.kv_transfer_params["do_remote_prefill"] is False
@patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.NixlWrapper",
FakeNixlWrapper,
@@ -525,7 +525,9 @@ class NixlConnectorScheduler:
if params.get("do_remote_prefill"):
# If do_remote_prefill is still True when the request is finished,
# update_state_after_alloc must not have been called (the request
# must have been aborted before it was scheduled, e.g. via the
# abort_immediately path used to clean up KV-transfer requests
# rejected at the D-side serving layer).
# To avoid stranding the prefill blocks in the prefill instance,
# we must add empty block_ids to _reqs_need_recv so that our
# worker side will notify and free blocks in the prefill instance.
+14
View File
@@ -108,6 +108,20 @@ class EngineClient(ABC):
"""
...
@abstractmethod
async def notify_kv_transfer_request_rejected(
self,
request_id: str,
kv_transfer_params: dict[str, Any],
*,
data_parallel_rank: int | None = None,
) -> None:
"""Notify the engine that a KV-transfer request was rejected before
engine admission, so connector-side cleanup can run (e.g. free
prefill blocks pinned on the P node).
"""
...
@abstractmethod
async def is_tracing_enabled(self) -> bool:
    """Abstract hook returning a bool.

    NOTE(review): presumably reports whether tracing is enabled for the
    engine; confirm against the concrete implementations.
    """
    ...
@@ -234,6 +234,15 @@ class OpenAIServingChat(OpenAIServing):
for the API specification. This API mimics the OpenAI
Chat Completion API.
"""
return await self._with_kv_transfer_rejection_cleanup(
self._create_chat_completion(request, raw_request), request, raw_request
)
async def _create_chat_completion(
self,
request: ChatCompletionRequest,
raw_request: Request | None = None,
) -> AsyncGenerator[str, None] | ChatCompletionResponse | ErrorResponse:
# Streaming response
tokenizer = self.renderer.tokenizer
assert tokenizer is not None
@@ -118,6 +118,15 @@ class OpenAIServingCompletion(OpenAIServing):
- suffix (the language models we currently support do not support
suffix)
"""
return await self._with_kv_transfer_rejection_cleanup(
self._create_completion(request, raw_request), request, raw_request
)
async def _create_completion(
self,
request: CompletionRequest,
raw_request: Request | None = None,
) -> AsyncGenerator[str, None] | CompletionResponse | ErrorResponse:
if request.stream and request.use_beam_search:
return self.create_error_response(
"Streaming is not currently supported with beam search"
+39 -1
View File
@@ -4,7 +4,7 @@ import asyncio
import contextlib
import json
import time
from collections.abc import AsyncGenerator, Mapping
from collections.abc import AsyncGenerator, Awaitable, Mapping
from dataclasses import dataclass, field
from http import HTTPStatus
from typing import Any, ClassVar, Generic, Protocol, TypeAlias, TypeVar
@@ -118,6 +118,7 @@ AnyResponse: TypeAlias = (
)
# Type variable for request models; bounded by ``AnyRequest``.
RequestT = TypeVar("RequestT", bound=AnyRequest)
# Unconstrained type variable: the result type of the awaitable wrapped by
# ``OpenAIServing._with_kv_transfer_rejection_cleanup``.
_T = TypeVar("_T")
@dataclass(kw_only=True)
@@ -156,6 +157,9 @@ class OpenAIServing:
self.model_config = engine_client.model_config
self.renderer = engine_client.renderer
self.input_processor = engine_client.input_processor
vllm_config = getattr(engine_client, "vllm_config", None)
kv_transfer_config = getattr(vllm_config, "kv_transfer_config", None)
self.has_kv_connector = kv_transfer_config is not None
# Computed once at startup (cached by ``vllm_config`` identity) and
# stamped on non-streaming responses. Streaming chunks deliberately
@@ -616,6 +620,40 @@ class OpenAIServing:
except ValueError:
return None
async def _with_kv_transfer_rejection_cleanup(
    self,
    awaitable: Awaitable[_T],
    request: ChatCompletionRequest | CompletionRequest | ResponsesRequest,
    raw_request: Request | None,
) -> _T:
    """Await a `create_*` coroutine and clean up KV transfer on rejection.

    When a KV connector is configured and the request asks for remote
    prefill, the engine client is notified if the awaitable either raises
    or resolves to an ErrorResponse. This lets the connector free any
    remote-prefill blocks pinned for the request. In every other case the
    awaitable is simply awaited and its result returned.

    Args:
        awaitable: The pending `create_*` call to run.
        request: The API request; its `kv_transfer_params` and
            `request_id` are read.
        raw_request: The raw HTTP request, used to look up the
            data-parallel rank for the notification.

    Returns:
        Whatever the awaitable returns, including an ErrorResponse.

    Raises:
        Any exception raised by the awaitable. Exception subclasses raised
        while notifying are logged as a warning and not propagated.
    """
    # `and` short-circuits, so the request is only inspected when a KV
    # connector is configured.
    transfer_params = self.has_kv_connector and request.kv_transfer_params
    if not (transfer_params and transfer_params.get("do_remote_prefill")):
        return await awaitable

    # Treat the request as rejected until the awaitable produces a
    # non-error result; an exception escaping it leaves the flag set.
    rejected = True
    try:
        outcome = await awaitable
        rejected = isinstance(outcome, ErrorResponse)
        return outcome
    finally:
        if rejected:
            # Best-effort: a failed notification must not mask the original
            # result or exception.
            try:
                await self.engine_client.notify_kv_transfer_request_rejected(
                    request.request_id,
                    transfer_params,
                    data_parallel_rank=self._get_data_parallel_rank(raw_request),
                )
            except Exception:
                logger.warning(
                    "Failed to notify KV connector about rejected request %s",
                    request.request_id,
                    exc_info=True,
                )
@staticmethod
def _parse_tool_calls_from_content(
request: ResponsesRequest | ChatCompletionRequest,
@@ -323,6 +323,17 @@ class OpenAIServingResponses(OpenAIServing):
AsyncGenerator[StreamingResponsesResponse, None]
| ResponsesResponse
| ErrorResponse
):
return await self._with_kv_transfer_rejection_cleanup(
self._create_responses(request, raw_request), request, raw_request
)
async def _create_responses(
self, request: ResponsesRequest, raw_request: Request | None = None
) -> (
AsyncGenerator[StreamingResponsesResponse, None]
| ResponsesResponse
| ErrorResponse
):
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
+6
View File
@@ -122,6 +122,12 @@ class EngineCoreRequest(
reasoning_ended: bool | None = None
reasoning_parser_kwargs: dict[str, Any] | None = None
# If True, the request should be added to the scheduler's waiting queue
# and immediately aborted, so connector-side cleanup runs via the standard
# request_finished hook. Used to free P-side prefill blocks when a
# KV-transfer request is rejected on the D node before engine admission.
abort_immediately: bool = False
@property
def params(self) -> SamplingParams | PoolingParams:
"""Return the processed params (sampling or pooling)."""
+27
View File
@@ -720,6 +720,33 @@ class AsyncLLM(EngineClient):
if self.log_requests:
logger.info("Aborted request(s) %s.", ",".join(request_ids))
async def notify_kv_transfer_request_rejected(
    self,
    request_id: str,
    kv_transfer_params: dict[str, Any],
    *,
    data_parallel_rank: int | None = None,
) -> None:
    """Send the engine core a pre-aborted request for connector cleanup.

    Builds a minimal EngineCoreRequest flagged with `abort_immediately=True`
    and submits it, so the connector's request_finished hook runs and can
    free pre-admission KV-transfer resources (e.g. NIXL prefill blocks
    pinned on the P node).

    Args:
        request_id: Identifier to give the pre-aborted request.
        kv_transfer_params: KV-transfer parameters of the rejected request;
            a shallow copy is placed in the sampling params' `extra_args`.
        data_parallel_rank: Optional data-parallel rank forwarded on the
            request; keyword-only.
    """
    # Placeholder sampling params: the request is aborted right after being
    # added, so they only carry a copy of the KV-transfer parameters.
    placeholder_sampling = SamplingParams(
        max_tokens=1,
        extra_args={"kv_transfer_params": dict(kv_transfer_params)},
    )
    pre_aborted = EngineCoreRequest(
        request_id=request_id,
        # Single dummy token as the prompt.
        prompt_token_ids=[0],
        mm_features=None,
        sampling_params=placeholder_sampling,
        pooling_params=None,
        arrival_time=time.time(),
        lora_request=None,
        cache_salt=None,
        data_parallel_rank=data_parallel_rank,
        abort_immediately=True,
    )
    await self.engine_core.add_request_async(pre_aborted)
async def pause_generation(
self,
*,
+4
View File
@@ -344,6 +344,10 @@ class EngineCore:
)
self.scheduler.add_request(request)
if request.abort_immediately:
# Immediately abort so the connector's request_finished hook runs
# to free any pre-admission KV-transfer resources.
self.abort_requests([request.request_id])
def abort_requests(self, request_ids: list[str]):
"""Abort requests from the scheduler."""
+6
View File
@@ -76,6 +76,7 @@ class Request:
resumable: bool = False,
reasoning_ended: bool | None = None,
reasoning_parser_kwargs: dict[str, Any] | None = None,
abort_immediately: bool = False,
) -> None:
self.request_id = request_id
self.client_index = client_index
@@ -182,6 +183,10 @@ class Request:
# None entry in the queue means finished.
self.streaming_queue: deque[StreamingUpdate | None] | None = None
# If True, request should be aborted immediately after being added to
# the scheduler so the connector's request_finished hook runs.
self.abort_immediately = abort_immediately
@classmethod
def from_engine_core_request(
cls,
@@ -206,6 +211,7 @@ class Request:
resumable=request.resumable,
reasoning_ended=request.reasoning_ended,
reasoning_parser_kwargs=request.reasoning_parser_kwargs,
abort_immediately=request.abort_immediately,
)
def append_output_token_ids(