[KV Connector] Keep MooncakeStore full hits block-aligned (#43494)

Signed-off-by: Dao Le <daole@inferact.ai>
Signed-off-by: Dao Le <Dao007forever@gmail.com>
Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
Dao007forever
2026-05-23 23:15:03 -07:00
committed by GitHub
parent 33d7cbe02c
commit 0902d8e62f
4 changed files with 74 additions and 39 deletions
@@ -204,7 +204,6 @@ the vLLM JSON config.
- `load_async` (bool): Enable asynchronous loading for better compute-I/O overlap. Default: `true`.
- `enable_cross_layers_blocks` (bool): Enable cross-layer block packing for reduced store operations. Default: `false`.
- `discard_partial_chunks` (bool): Discard partial block chunks during store. Default: `true`.
- `lookup_rpc_port` (int): Custom port for the ZMQ lookup RPC socket. Default: `0`.
## Notes
@@ -18,7 +18,6 @@ def _make_bare_scheduler() -> MooncakeStoreScheduler:
scheduler.kv_role = "kv_both"
scheduler.original_block_size = 16
scheduler._block_size = 16
scheduler._discard_partial_chunks = True
scheduler.load_specs = {}
scheduler._preempted_req_ids = set()
scheduler._unfinished_request_ids = {"req-0"}
@@ -344,3 +343,66 @@ def test_from_request_tracker_no_load_saves_normally():
assert req_meta.can_save is True
assert req_meta.load_spec is None
assert tracker.num_saved_tokens == 48
class _StubLookupClient:
def __init__(self, hit_tokens: int) -> None:
self._hit_tokens = hit_tokens
def lookup(self, token_len: int, block_hashes: list[bytes]) -> int:
return self._hit_tokens
def test_full_external_hit_keeps_kvpool_cached_tokens_block_aligned():
# When the external store hits the entire prompt, scheduler must leave at
# least one token uncomputed for sampling but stay on a block boundary.
# Otherwise the recv-side load mask floors token_len to
# (num_tokens-1)//block_size, the tail partial chunk is dropped, and -- if
# the local cache covers the aligned prefix -- key_list ends up empty
# (ZeroDivisionError in the recv thread's `tp_rank % len(key_list)`).
scheduler = _make_bare_scheduler()
scheduler.load_async = True
scheduler.client = _StubLookupClient(hit_tokens=48) # full hit on 48-token prompt
request = SimpleNamespace(
request_id="req-0",
num_tokens=48,
block_hashes=[b"h0", b"h1", b"h2"],
)
need_to_allocate, load_async = scheduler.get_num_new_matched_tokens(
request, num_computed_tokens=16
)
# 47 // 16 * 16 == 32 tokens left in external store after reserving the
# sub-block tail for sampling. 32 - 16 (local) == 16 to load.
assert need_to_allocate == 16
assert load_async is True
load_spec = scheduler.load_specs["req-0"]
assert load_spec.vllm_cached_tokens == 16
assert load_spec.kvpool_cached_tokens == 32
assert load_spec.kvpool_cached_tokens % 16 == 0
def test_full_external_hit_with_full_local_hit_skips_load():
# When local prefix cache already covers the block-aligned external hit,
# there is nothing for the connector to load. The pre-fix behavior would
# have scheduled a 15-token load that the recv thread couldn't translate
# into any block-aligned key.
scheduler = _make_bare_scheduler()
scheduler.load_async = True
scheduler.client = _StubLookupClient(hit_tokens=48)
request = SimpleNamespace(
request_id="req-0",
num_tokens=48,
block_hashes=[b"h0", b"h1", b"h2"],
)
need_to_allocate, load_async = scheduler.get_num_new_matched_tokens(
request, num_computed_tokens=32
)
assert need_to_allocate == 0
assert load_async is False
assert "req-0" not in scheduler.load_specs
@@ -68,12 +68,6 @@ class MooncakeStoreScheduler:
kv_cache_config, vllm_config
)
self._discard_partial_chunks = (
vllm_config.kv_transfer_config.get_from_extra_config(
"discard_partial_chunks", True
)
)
# Per-request state
self.load_specs: dict[str, LoadSpec] = {} # to be loaded
self._request_trackers: dict[str, RequestTracker] = {} # scheduled new requests
@@ -88,18 +82,19 @@ class MooncakeStoreScheduler:
) -> tuple[int, bool]:
"""Check for external KV cache hit."""
# Look up against the full prefill range, not just the prompt.
if self._discard_partial_chunks:
token_len = request.num_tokens // self._block_size * self._block_size
else:
token_len = request.num_tokens
token_len = request.num_tokens // self._block_size * self._block_size
if token_len < self._block_size:
return 0, False
num_external_hit_tokens = self.client.lookup(token_len, request.block_hashes)
if num_external_hit_tokens == request.num_tokens:
num_external_hit_tokens -= 1
# Leave a sub-block tail uncomputed for sampling, on a block
# boundary so the recv-side load mask covers every yielded chunk.
num_external_hit_tokens = max(
0,
(request.num_tokens - 1) // self._block_size * self._block_size,
)
if num_external_hit_tokens < num_computed_tokens:
need_to_allocate = 0
@@ -214,9 +209,7 @@ class MooncakeStoreScheduler:
self._request_trackers[request.req_id] = request_tracker
last_chunk_tokens_num = (
(len(prefill_tokens) // self._block_size * self._block_size)
if self._discard_partial_chunks
else len(prefill_tokens)
len(prefill_tokens) // self._block_size * self._block_size
)
req_meta = ReqMeta.from_request_tracker(
@@ -226,7 +219,6 @@ class MooncakeStoreScheduler:
skip_save=force_skip_save,
block_hashes=request_real.block_hashes,
is_last_chunk=(request_tracker.token_len >= last_chunk_tokens_num),
discard_partial_chunks=self._discard_partial_chunks,
original_block_size=self.original_block_size,
)
if req_meta is not None:
@@ -269,9 +261,7 @@ class MooncakeStoreScheduler:
self._request_trackers[req_id] = request_tracker
last_chunk_tokens_num = (
(len(prefill_tokens) // self._block_size * self._block_size)
if self._discard_partial_chunks
else len(prefill_tokens)
len(prefill_tokens) // self._block_size * self._block_size
)
req_meta = ReqMeta.from_request_tracker(
request_tracker,
@@ -282,7 +272,6 @@ class MooncakeStoreScheduler:
is_last_chunk=(
request_tracker.token_len >= last_chunk_tokens_num
),
discard_partial_chunks=self._discard_partial_chunks,
original_block_size=self.original_block_size,
)
else:
@@ -310,9 +299,7 @@ class MooncakeStoreScheduler:
request_tracker.update(new_block_ids)
last_chunk_tokens_num = (
(prefill_end // self._block_size * self._block_size)
if self._discard_partial_chunks
else prefill_end
prefill_end // self._block_size * self._block_size
)
req_meta = ReqMeta.from_request_tracker(
request_tracker,
@@ -323,7 +310,6 @@ class MooncakeStoreScheduler:
is_last_chunk=(
request_tracker.token_len >= last_chunk_tokens_num
),
discard_partial_chunks=self._discard_partial_chunks,
original_block_size=self.original_block_size,
)
@@ -341,10 +327,6 @@ class MooncakeStoreScheduler:
if not load_spec:
continue
num_tokens_to_compute = load_spec.kvpool_cached_tokens
if (num_tokens_to_compute % self._block_size != 0) and (
num_tokens_to_compute == unfinished_req.num_tokens - 1
):
num_tokens_to_compute = num_tokens_to_compute + 1
request_tracker = RequestTracker(
req_id=request_id,
token_len=num_tokens_to_compute,
@@ -358,7 +340,6 @@ class MooncakeStoreScheduler:
load_spec=load_spec,
skip_save=None,
block_hashes=unfinished_req.block_hashes,
discard_partial_chunks=self._discard_partial_chunks,
)
if req_meta is not None:
meta.add_request(req_meta)
@@ -1195,14 +1195,7 @@ class MooncakeStoreWorker:
if load_spec is None or not load_spec.can_load:
continue
token_len = request.token_len_chunk
if (load_spec.kvpool_cached_tokens % self.block_size != 0) and (
load_spec.kvpool_cached_tokens == token_len - 1
):
token_len = load_spec.kvpool_cached_tokens + 1
else:
token_len = load_spec.kvpool_cached_tokens
load_spec.token_len = token_len
load_spec.token_len = load_spec.kvpool_cached_tokens
assert self.kv_recv_thread is not None
self.kv_recv_thread.add_request(request)