mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[KV Connector] Keep MooncakeStore full hits block-aligned (#43494)
Signed-off-by: Dao Le <daole@inferact.ai>
Signed-off-by: Dao Le <Dao007forever@gmail.com>
Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -204,7 +204,6 @@ the vLLM JSON config.
|
||||
|
||||
- `load_async` (bool): Enable asynchronous loading for better compute-I/O overlap. Default: `true`.
|
||||
- `enable_cross_layers_blocks` (bool): Enable cross-layer block packing for reduced store operations. Default: `false`.
|
||||
- `discard_partial_chunks` (bool): Discard partial block chunks during store. Default: `true`.
|
||||
- `lookup_rpc_port` (int): Custom port for the ZMQ lookup RPC socket. Default: `0`.
|
||||
|
||||
## Notes
|
||||
|
||||
@@ -18,7 +18,6 @@ def _make_bare_scheduler() -> MooncakeStoreScheduler:
|
||||
scheduler.kv_role = "kv_both"
|
||||
scheduler.original_block_size = 16
|
||||
scheduler._block_size = 16
|
||||
scheduler._discard_partial_chunks = True
|
||||
scheduler.load_specs = {}
|
||||
scheduler._preempted_req_ids = set()
|
||||
scheduler._unfinished_request_ids = {"req-0"}
|
||||
@@ -344,3 +343,66 @@ def test_from_request_tracker_no_load_saves_normally():
|
||||
assert req_meta.can_save is True
|
||||
assert req_meta.load_spec is None
|
||||
assert tracker.num_saved_tokens == 48
|
||||
|
||||
|
||||
class _StubLookupClient:
|
||||
def __init__(self, hit_tokens: int) -> None:
|
||||
self._hit_tokens = hit_tokens
|
||||
|
||||
def lookup(self, token_len: int, block_hashes: list[bytes]) -> int:
|
||||
return self._hit_tokens
|
||||
|
||||
|
||||
def test_full_external_hit_keeps_kvpool_cached_tokens_block_aligned():
    """A full external hit must be trimmed to a block boundary."""
    # When the external store hits the entire prompt, scheduler must leave at
    # least one token uncomputed for sampling but stay on a block boundary.
    # Otherwise the recv-side load mask floors token_len to
    # (num_tokens-1)//block_size, the tail partial chunk is dropped, and -- if
    # the local cache covers the aligned prefix -- key_list ends up empty
    # (ZeroDivisionError in the recv thread's `tp_rank % len(key_list)`).
    scheduler = _make_bare_scheduler()
    scheduler.load_async = True
    # Full hit on the 48-token prompt.
    scheduler.client = _StubLookupClient(hit_tokens=48)

    request = SimpleNamespace(
        request_id="req-0",
        num_tokens=48,
        block_hashes=[b"h0", b"h1", b"h2"],
    )

    need_to_allocate, load_async = scheduler.get_num_new_matched_tokens(
        request, num_computed_tokens=16
    )

    # 47 // 16 * 16 == 32 tokens left in external store after reserving the
    # sub-block tail for sampling. 32 - 16 (local) == 16 to load.
    assert need_to_allocate == 16
    assert load_async is True
    load_spec = scheduler.load_specs["req-0"]
    assert load_spec.vllm_cached_tokens == 16
    assert load_spec.kvpool_cached_tokens == 32
    assert load_spec.kvpool_cached_tokens % 16 == 0
|
||||
|
||||
|
||||
def test_full_external_hit_with_full_local_hit_skips_load():
    """No load is scheduled when the local cache covers the aligned hit."""
    # When local prefix cache already covers the block-aligned external hit,
    # there is nothing for the connector to load. The pre-fix behavior would
    # have scheduled a 15-token load that the recv thread couldn't translate
    # into any block-aligned key.
    scheduler = _make_bare_scheduler()
    scheduler.load_async = True
    scheduler.client = _StubLookupClient(hit_tokens=48)

    request = SimpleNamespace(
        request_id="req-0",
        num_tokens=48,
        block_hashes=[b"h0", b"h1", b"h2"],
    )

    need_to_allocate, load_async = scheduler.get_num_new_matched_tokens(
        request, num_computed_tokens=32
    )

    assert need_to_allocate == 0
    assert load_async is False
    assert "req-0" not in scheduler.load_specs
|
||||
|
||||
@@ -68,12 +68,6 @@ class MooncakeStoreScheduler:
|
||||
kv_cache_config, vllm_config
|
||||
)
|
||||
|
||||
self._discard_partial_chunks = (
|
||||
vllm_config.kv_transfer_config.get_from_extra_config(
|
||||
"discard_partial_chunks", True
|
||||
)
|
||||
)
|
||||
|
||||
# Per-request state
|
||||
self.load_specs: dict[str, LoadSpec] = {} # to be loaded
|
||||
self._request_trackers: dict[str, RequestTracker] = {} # scheduled new requests
|
||||
@@ -88,18 +82,19 @@ class MooncakeStoreScheduler:
|
||||
) -> tuple[int, bool]:
|
||||
"""Check for external KV cache hit."""
|
||||
# Look up against the full prefill range, not just the prompt.
|
||||
if self._discard_partial_chunks:
|
||||
token_len = request.num_tokens // self._block_size * self._block_size
|
||||
else:
|
||||
token_len = request.num_tokens
|
||||
|
||||
token_len = request.num_tokens // self._block_size * self._block_size
|
||||
if token_len < self._block_size:
|
||||
return 0, False
|
||||
|
||||
num_external_hit_tokens = self.client.lookup(token_len, request.block_hashes)
|
||||
|
||||
if num_external_hit_tokens == request.num_tokens:
|
||||
num_external_hit_tokens -= 1
|
||||
# Leave a sub-block tail uncomputed for sampling, on a block
|
||||
# boundary so the recv-side load mask covers every yielded chunk.
|
||||
num_external_hit_tokens = max(
|
||||
0,
|
||||
(request.num_tokens - 1) // self._block_size * self._block_size,
|
||||
)
|
||||
|
||||
if num_external_hit_tokens < num_computed_tokens:
|
||||
need_to_allocate = 0
|
||||
@@ -214,9 +209,7 @@ class MooncakeStoreScheduler:
|
||||
self._request_trackers[request.req_id] = request_tracker
|
||||
|
||||
last_chunk_tokens_num = (
|
||||
(len(prefill_tokens) // self._block_size * self._block_size)
|
||||
if self._discard_partial_chunks
|
||||
else len(prefill_tokens)
|
||||
len(prefill_tokens) // self._block_size * self._block_size
|
||||
)
|
||||
|
||||
req_meta = ReqMeta.from_request_tracker(
|
||||
@@ -226,7 +219,6 @@ class MooncakeStoreScheduler:
|
||||
skip_save=force_skip_save,
|
||||
block_hashes=request_real.block_hashes,
|
||||
is_last_chunk=(request_tracker.token_len >= last_chunk_tokens_num),
|
||||
discard_partial_chunks=self._discard_partial_chunks,
|
||||
original_block_size=self.original_block_size,
|
||||
)
|
||||
if req_meta is not None:
|
||||
@@ -269,9 +261,7 @@ class MooncakeStoreScheduler:
|
||||
self._request_trackers[req_id] = request_tracker
|
||||
|
||||
last_chunk_tokens_num = (
|
||||
(len(prefill_tokens) // self._block_size * self._block_size)
|
||||
if self._discard_partial_chunks
|
||||
else len(prefill_tokens)
|
||||
len(prefill_tokens) // self._block_size * self._block_size
|
||||
)
|
||||
req_meta = ReqMeta.from_request_tracker(
|
||||
request_tracker,
|
||||
@@ -282,7 +272,6 @@ class MooncakeStoreScheduler:
|
||||
is_last_chunk=(
|
||||
request_tracker.token_len >= last_chunk_tokens_num
|
||||
),
|
||||
discard_partial_chunks=self._discard_partial_chunks,
|
||||
original_block_size=self.original_block_size,
|
||||
)
|
||||
else:
|
||||
@@ -310,9 +299,7 @@ class MooncakeStoreScheduler:
|
||||
request_tracker.update(new_block_ids)
|
||||
|
||||
last_chunk_tokens_num = (
|
||||
(prefill_end // self._block_size * self._block_size)
|
||||
if self._discard_partial_chunks
|
||||
else prefill_end
|
||||
prefill_end // self._block_size * self._block_size
|
||||
)
|
||||
req_meta = ReqMeta.from_request_tracker(
|
||||
request_tracker,
|
||||
@@ -323,7 +310,6 @@ class MooncakeStoreScheduler:
|
||||
is_last_chunk=(
|
||||
request_tracker.token_len >= last_chunk_tokens_num
|
||||
),
|
||||
discard_partial_chunks=self._discard_partial_chunks,
|
||||
original_block_size=self.original_block_size,
|
||||
)
|
||||
|
||||
@@ -341,10 +327,6 @@ class MooncakeStoreScheduler:
|
||||
if not load_spec:
|
||||
continue
|
||||
num_tokens_to_compute = load_spec.kvpool_cached_tokens
|
||||
if (num_tokens_to_compute % self._block_size != 0) and (
|
||||
num_tokens_to_compute == unfinished_req.num_tokens - 1
|
||||
):
|
||||
num_tokens_to_compute = num_tokens_to_compute + 1
|
||||
request_tracker = RequestTracker(
|
||||
req_id=request_id,
|
||||
token_len=num_tokens_to_compute,
|
||||
@@ -358,7 +340,6 @@ class MooncakeStoreScheduler:
|
||||
load_spec=load_spec,
|
||||
skip_save=None,
|
||||
block_hashes=unfinished_req.block_hashes,
|
||||
discard_partial_chunks=self._discard_partial_chunks,
|
||||
)
|
||||
if req_meta is not None:
|
||||
meta.add_request(req_meta)
|
||||
|
||||
@@ -1195,14 +1195,7 @@ class MooncakeStoreWorker:
|
||||
if load_spec is None or not load_spec.can_load:
|
||||
continue
|
||||
|
||||
token_len = request.token_len_chunk
|
||||
if (load_spec.kvpool_cached_tokens % self.block_size != 0) and (
|
||||
load_spec.kvpool_cached_tokens == token_len - 1
|
||||
):
|
||||
token_len = load_spec.kvpool_cached_tokens + 1
|
||||
else:
|
||||
token_len = load_spec.kvpool_cached_tokens
|
||||
load_spec.token_len = token_len
|
||||
load_spec.token_len = load_spec.kvpool_cached_tokens
|
||||
|
||||
assert self.kv_recv_thread is not None
|
||||
self.kv_recv_thread.add_request(request)
|
||||
|
||||
Reference in New Issue
Block a user