mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[KV Offload] Add on_schedule_end() hook to separate step lifecycle from event draining (#44206)
Signed-off-by: Ronen Schaffer <ronen.schaffer@ibm.com>
This commit is contained in:
@@ -123,6 +123,11 @@ class TestTieringOffloadingManager:
|
||||
secondary_tiers=[self.secondary_tier1, self.secondary_tier2],
|
||||
)
|
||||
|
||||
def _simulate_on_schedule_end(self):
|
||||
"""Simulate end of scheduler step: lifecycle flush + drain events."""
|
||||
self.manager.on_schedule_end()
|
||||
list(self.manager.take_events())
|
||||
|
||||
def test_basic_store_to_primary(self, manager_setup):
|
||||
"""Test basic store operation to primary tier."""
|
||||
blocks = to_keys(range(3))
|
||||
@@ -185,10 +190,10 @@ class TestTieringOffloadingManager:
|
||||
assert block.ref_cnt == 2
|
||||
|
||||
# End of step 1: _maybe_process_finished_jobs() was already called by
|
||||
# prepare_store() above (setting the per-step flag), so take_events()
|
||||
# prepare_store() above (setting the per-step flag), so on_schedule_end()
|
||||
# does NOT poll get_finished_jobs() again — cascade completions remain
|
||||
# unprocessed until the next step.
|
||||
list(self.manager.take_events())
|
||||
self._simulate_on_schedule_end()
|
||||
|
||||
# ref_cnt still held: cascade jobs finished (sync tier) but haven't
|
||||
# been polled yet because the per-step guard skipped the second call.
|
||||
@@ -202,7 +207,7 @@ class TestTieringOffloadingManager:
|
||||
|
||||
# End of step 2: flag was reset, so _maybe_process_finished_jobs()
|
||||
# runs and processes the cascade completions (complete_read → ref_cnt--)
|
||||
list(self.manager.take_events())
|
||||
self._simulate_on_schedule_end()
|
||||
|
||||
# After cascade completes, ref_cnt should be 0
|
||||
for block_hash in blocks:
|
||||
@@ -238,10 +243,10 @@ class TestTieringOffloadingManager:
|
||||
assert result is None # Retry later (promotion initiated)
|
||||
|
||||
# End of step 1: flushes deferred submit_load() calls
|
||||
list(self.manager.take_events())
|
||||
self._simulate_on_schedule_end()
|
||||
|
||||
# End of step 2: processes the completed promotion jobs
|
||||
list(self.manager.take_events())
|
||||
self._simulate_on_schedule_end()
|
||||
|
||||
# Now blocks should be in primary tier
|
||||
assert count_hits(self.primary_tier, blocks) == 3
|
||||
@@ -271,7 +276,7 @@ class TestTieringOffloadingManager:
|
||||
self.manager.complete_store(blocks, _CTX, success=True)
|
||||
|
||||
# End of step: release ref_cnt from cascade
|
||||
list(self.manager.take_events())
|
||||
self._simulate_on_schedule_end()
|
||||
|
||||
# Now try to store 2 more blocks (should trigger eviction)
|
||||
more_blocks = to_keys(range(5, 7))
|
||||
@@ -289,7 +294,7 @@ class TestTieringOffloadingManager:
|
||||
# Store blocks
|
||||
self.manager.prepare_store(blocks, _CTX)
|
||||
self.manager.complete_store(blocks, _CTX, success=True)
|
||||
list(self.manager.take_events())
|
||||
self._simulate_on_schedule_end()
|
||||
|
||||
self.secondary_tier1.touch = MagicMock(wraps=self.secondary_tier1.touch)
|
||||
self.secondary_tier2.touch = MagicMock(wraps=self.secondary_tier2.touch)
|
||||
@@ -328,7 +333,7 @@ class TestTieringOffloadingManager:
|
||||
self.secondary_tier2.submit_store.assert_not_called()
|
||||
|
||||
def test_lookup_batches_submit_load_per_request(self, manager_setup):
|
||||
"""lookup() defers submit_load until take_events(), one call per request.
|
||||
"""lookup() defers submit_load until on_schedule_end(), one per request.
|
||||
|
||||
Blocks from different requests each get their own submit_load call, each
|
||||
carrying the correct req_context.
|
||||
@@ -354,7 +359,7 @@ class TestTieringOffloadingManager:
|
||||
self.secondary_tier1.submit_load.assert_not_called()
|
||||
|
||||
# simulate end of step
|
||||
list(self.manager.take_events())
|
||||
self._simulate_on_schedule_end()
|
||||
|
||||
assert self.secondary_tier1.submit_load.call_count == 2
|
||||
calls = self.secondary_tier1.submit_load.call_args_list
|
||||
@@ -389,7 +394,7 @@ class TestTieringOffloadingManager:
|
||||
assert result_a is None
|
||||
assert result_b is None
|
||||
|
||||
list(self.manager.take_events())
|
||||
self._simulate_on_schedule_end()
|
||||
|
||||
# Only one submit_load call despite two lookups
|
||||
self.secondary_tier1.submit_load.assert_called_once()
|
||||
@@ -449,8 +454,7 @@ class TestTieringOffloadingManager:
|
||||
assert result is not None
|
||||
self.manager.complete_store(existing_blocks, _CTX, success=True)
|
||||
# Drain cascade completions
|
||||
list(self.manager.take_events())
|
||||
list(self.manager.take_events())
|
||||
self._simulate_on_schedule_end()
|
||||
|
||||
# Make tier1 request-level, tier2 stays block-level
|
||||
self.secondary_tier1.on_new_request = (
|
||||
|
||||
@@ -897,6 +897,7 @@ class OffloadingConnectorScheduler:
|
||||
self, scheduler_output: SchedulerOutput
|
||||
) -> KVConnectorMetadata:
|
||||
self._update_req_states(scheduler_output)
|
||||
self.manager.on_schedule_end()
|
||||
|
||||
# Flush jobs for preempted requests.
|
||||
for req_id in scheduler_output.preempted_req_ids or ():
|
||||
|
||||
@@ -257,6 +257,14 @@ class OffloadingManager(ABC):
|
||||
"""
|
||||
return ()
|
||||
|
||||
def on_schedule_end(self) -> None:
|
||||
"""Called once at the end of each scheduler step.
|
||||
|
||||
Managers may override this to flush deferred work accumulated
|
||||
during the step (e.g., batched promotions).
|
||||
"""
|
||||
return
|
||||
|
||||
def reset_cache(self) -> None:
|
||||
"""Evict all tracked blocks and reset internal state."""
|
||||
return
|
||||
|
||||
@@ -185,6 +185,14 @@ class SecondaryTierManager(ABC):
|
||||
"""
|
||||
return
|
||||
|
||||
def on_schedule_end(self) -> None:
|
||||
"""Called once at the end of each scheduler step.
|
||||
|
||||
Secondary tiers may override this for per-step cleanup or
|
||||
deferred work submission.
|
||||
"""
|
||||
return
|
||||
|
||||
def shutdown(self) -> None:
|
||||
"""Release resources held by this tier (threads, connections, etc.)."""
|
||||
return
|
||||
|
||||
@@ -152,7 +152,7 @@ class TieringOffloadingManager(OffloadingManager):
|
||||
self._transfer_jobs: dict[JobId, JobMetadata] = {}
|
||||
|
||||
# Pending promotion requests accumulated during lookup() calls; flushed
|
||||
# as one batched submit_load() per (tier, request) in take_events().
|
||||
# as one batched submit_load() per (tier, request) in on_schedule_end().
|
||||
# Outer key: tier. Inner key: req_context.req_id — the same ReqContext
|
||||
# object is reused for all block lookups of a given request per engine step.
|
||||
self._pending_load_submissions: dict[
|
||||
@@ -160,7 +160,7 @@ class TieringOffloadingManager(OffloadingManager):
|
||||
] = {}
|
||||
|
||||
# Gate for once-per-step execution of _maybe_process_finished_jobs().
|
||||
# Reset at the end of each step in take_events().
|
||||
# Reset at the end of each step in on_schedule_end().
|
||||
self._processed_jobs_this_step: bool = False
|
||||
|
||||
# Per-request set of secondary tiers that requested REQUEST_LEVEL
|
||||
@@ -182,7 +182,7 @@ class TieringOffloadingManager(OffloadingManager):
|
||||
|
||||
Guarded by _processed_jobs_this_step: the first call in an engine step
|
||||
does the actual polling; subsequent calls are no-ops. The flag is reset
|
||||
in take_events() at the end of each step.
|
||||
in on_schedule_end() at the end of each step.
|
||||
"""
|
||||
if self._processed_jobs_this_step:
|
||||
return
|
||||
@@ -304,8 +304,8 @@ class TieringOffloadingManager(OffloadingManager):
|
||||
|
||||
store_spec = primary_write_result.store_spec
|
||||
assert isinstance(store_spec, CPULoadStoreSpec)
|
||||
# Defer submit_load to take_events(). Group by (tier, request) so each
|
||||
# request's blocks are submitted as one batched job per tier.
|
||||
# Defer submit_load to on_schedule_end(). Group by (tier, request) so
|
||||
# each request's blocks are submitted as one batched job per tier.
|
||||
tier_pending = self._pending_load_submissions.setdefault(tier, {})
|
||||
ctx_id = req_context.req_id
|
||||
if ctx_id not in tier_pending:
|
||||
@@ -320,8 +320,8 @@ class TieringOffloadingManager(OffloadingManager):
|
||||
def _flush_pending_promotions(self) -> None:
|
||||
"""Submit one batched submit_load() per (tier, request).
|
||||
|
||||
Called from take_events() at the end of each engine step, flushing
|
||||
all promotion requests deferred during lookup().
|
||||
Called from on_schedule_end() at the end of each scheduler step,
|
||||
flushing all promotion requests deferred during lookup().
|
||||
"""
|
||||
if not self._pending_load_submissions:
|
||||
return
|
||||
@@ -564,32 +564,26 @@ class TieringOffloadingManager(OffloadingManager):
|
||||
self._request_level_tiers.pop(req_context.req_id, None)
|
||||
|
||||
@override
|
||||
def take_events(self) -> Iterable[OffloadingEvent]:
|
||||
"""
|
||||
End-of-step hook: flush deferred work, yield events, reset per-step state.
|
||||
def on_schedule_end(self) -> None:
|
||||
"""End-of-schedule hook: process finished jobs, flush deferred
|
||||
promotions, and reset the per-step gate.
|
||||
|
||||
Called once per engine step from Scheduler.update_from_output() →
|
||||
connector.take_events(). Ensures _maybe_process_finished_jobs() has run
|
||||
at least once this step, flushes pending promotions, yields collected
|
||||
events, and resets the per-step flag.
|
||||
Called once per scheduler step from
|
||||
OffloadingConnectorScheduler.build_connector_meta().
|
||||
"""
|
||||
self._maybe_process_finished_jobs()
|
||||
self._processed_jobs_this_step = False
|
||||
self._flush_pending_promotions()
|
||||
for tier in self.secondary_tiers:
|
||||
tier.on_schedule_end()
|
||||
|
||||
@override
|
||||
def take_events(self) -> Iterable[OffloadingEvent]:
|
||||
"""Yield offloading events collected since the last call.
|
||||
|
||||
Yields:
|
||||
New OffloadingEvents collected since the last call.
|
||||
"""
|
||||
# TODO: Move _flush_pending_promotions() to a dedicated end_of_batch()
|
||||
# hook once one exists. For now, take_events() serves as the flush
|
||||
# point under the assumption that it is called at the end of each
|
||||
# engine step (Scheduler.update_from_output() → connector.take_events()).
|
||||
# When the dedicated hook is added, update tests that rely on
|
||||
# take_events() to signal end of step.
|
||||
|
||||
self._maybe_process_finished_jobs()
|
||||
|
||||
self._flush_pending_promotions()
|
||||
|
||||
# Reset the per-step gate so next step's first call does real work.
|
||||
self._processed_jobs_this_step = False
|
||||
|
||||
if self.events is not None:
|
||||
yield from self.events
|
||||
self.events.clear()
|
||||
|
||||
Reference in New Issue
Block a user