[KV Offload] Add on_schedule_end() hook to separate step lifecycle from event draining (#44206)

Signed-off-by: Ronen Schaffer <ronen.schaffer@ibm.com>
This commit is contained in:
Ronen Schaffer
2026-06-02 13:42:52 +03:00
committed by GitHub
parent 689b0eeb9e
commit 2a2b5ca791
5 changed files with 55 additions and 40 deletions
+16 -12
View File
@@ -123,6 +123,11 @@ class TestTieringOffloadingManager:
secondary_tiers=[self.secondary_tier1, self.secondary_tier2],
)
def _simulate_on_schedule_end(self):
"""Simulate end of scheduler step: lifecycle flush + drain events."""
self.manager.on_schedule_end()
list(self.manager.take_events())
def test_basic_store_to_primary(self, manager_setup):
"""Test basic store operation to primary tier."""
blocks = to_keys(range(3))
@@ -185,10 +190,10 @@ class TestTieringOffloadingManager:
assert block.ref_cnt == 2
# End of step 1: _maybe_process_finished_jobs() was already called by
# prepare_store() above (setting the per-step flag), so take_events()
# prepare_store() above (setting the per-step flag), so on_schedule_end()
# does NOT poll get_finished_jobs() again — cascade completions remain
# unprocessed until the next step.
list(self.manager.take_events())
self._simulate_on_schedule_end()
# ref_cnt still held: cascade jobs finished (sync tier) but haven't
# been polled yet because the per-step guard skipped the second call.
@@ -202,7 +207,7 @@ class TestTieringOffloadingManager:
# End of step 2: flag was reset, so _maybe_process_finished_jobs()
# runs and processes the cascade completions (complete_read → ref_cnt--)
list(self.manager.take_events())
self._simulate_on_schedule_end()
# After cascade completes, ref_cnt should be 0
for block_hash in blocks:
@@ -238,10 +243,10 @@ class TestTieringOffloadingManager:
assert result is None # Retry later (promotion initiated)
# End of step 1: flushes deferred submit_load() calls
list(self.manager.take_events())
self._simulate_on_schedule_end()
# End of step 2: processes the completed promotion jobs
list(self.manager.take_events())
self._simulate_on_schedule_end()
# Now blocks should be in primary tier
assert count_hits(self.primary_tier, blocks) == 3
@@ -271,7 +276,7 @@ class TestTieringOffloadingManager:
self.manager.complete_store(blocks, _CTX, success=True)
# End of step: release ref_cnt from cascade
list(self.manager.take_events())
self._simulate_on_schedule_end()
# Now try to store 2 more blocks (should trigger eviction)
more_blocks = to_keys(range(5, 7))
@@ -289,7 +294,7 @@ class TestTieringOffloadingManager:
# Store blocks
self.manager.prepare_store(blocks, _CTX)
self.manager.complete_store(blocks, _CTX, success=True)
list(self.manager.take_events())
self._simulate_on_schedule_end()
self.secondary_tier1.touch = MagicMock(wraps=self.secondary_tier1.touch)
self.secondary_tier2.touch = MagicMock(wraps=self.secondary_tier2.touch)
@@ -328,7 +333,7 @@ class TestTieringOffloadingManager:
self.secondary_tier2.submit_store.assert_not_called()
def test_lookup_batches_submit_load_per_request(self, manager_setup):
"""lookup() defers submit_load until take_events(), one call per request.
"""lookup() defers submit_load until on_schedule_end(), one per request.
Blocks from different requests each get their own submit_load call, each
carrying the correct req_context.
@@ -354,7 +359,7 @@ class TestTieringOffloadingManager:
self.secondary_tier1.submit_load.assert_not_called()
# simulate end of step
list(self.manager.take_events())
self._simulate_on_schedule_end()
assert self.secondary_tier1.submit_load.call_count == 2
calls = self.secondary_tier1.submit_load.call_args_list
@@ -389,7 +394,7 @@ class TestTieringOffloadingManager:
assert result_a is None
assert result_b is None
list(self.manager.take_events())
self._simulate_on_schedule_end()
# Only one submit_load call despite two lookups
self.secondary_tier1.submit_load.assert_called_once()
@@ -449,8 +454,7 @@ class TestTieringOffloadingManager:
assert result is not None
self.manager.complete_store(existing_blocks, _CTX, success=True)
# Drain cascade completions
list(self.manager.take_events())
list(self.manager.take_events())
self._simulate_on_schedule_end()
# Make tier1 request-level, tier2 stays block-level
self.secondary_tier1.on_new_request = (
@@ -897,6 +897,7 @@ class OffloadingConnectorScheduler:
self, scheduler_output: SchedulerOutput
) -> KVConnectorMetadata:
self._update_req_states(scheduler_output)
self.manager.on_schedule_end()
# Flush jobs for preempted requests.
for req_id in scheduler_output.preempted_req_ids or ():
+8
View File
@@ -257,6 +257,14 @@ class OffloadingManager(ABC):
"""
return ()
def on_schedule_end(self) -> None:
"""Called once at the end of each scheduler step.
Managers may override this to flush deferred work accumulated
during the step (e.g., batched promotions).
"""
return
def reset_cache(self) -> None:
"""Evict all tracked blocks and reset internal state."""
return
+8
View File
@@ -185,6 +185,14 @@ class SecondaryTierManager(ABC):
"""
return
def on_schedule_end(self) -> None:
"""Called once at the end of each scheduler step.
Secondary tiers may override this for per-step cleanup or
deferred work submission.
"""
return
def shutdown(self) -> None:
"""Release resources held by this tier (threads, connections, etc.)."""
return
+22 -28
View File
@@ -152,7 +152,7 @@ class TieringOffloadingManager(OffloadingManager):
self._transfer_jobs: dict[JobId, JobMetadata] = {}
# Pending promotion requests accumulated during lookup() calls; flushed
# as one batched submit_load() per (tier, request) in take_events().
# as one batched submit_load() per (tier, request) in on_schedule_end().
# Outer key: tier. Inner key: req_context.req_id — the same ReqContext
# object is reused for all block lookups of a given request per engine step.
self._pending_load_submissions: dict[
@@ -160,7 +160,7 @@ class TieringOffloadingManager(OffloadingManager):
] = {}
# Gate for once-per-step execution of _maybe_process_finished_jobs().
# Reset at the end of each step in take_events().
# Reset at the end of each step in on_schedule_end().
self._processed_jobs_this_step: bool = False
# Per-request set of secondary tiers that requested REQUEST_LEVEL
@@ -182,7 +182,7 @@ class TieringOffloadingManager(OffloadingManager):
Guarded by _processed_jobs_this_step: the first call in an engine step
does the actual polling; subsequent calls are no-ops. The flag is reset
in take_events() at the end of each step.
in on_schedule_end() at the end of each step.
"""
if self._processed_jobs_this_step:
return
@@ -304,8 +304,8 @@ class TieringOffloadingManager(OffloadingManager):
store_spec = primary_write_result.store_spec
assert isinstance(store_spec, CPULoadStoreSpec)
# Defer submit_load to take_events(). Group by (tier, request) so each
# request's blocks are submitted as one batched job per tier.
# Defer submit_load to on_schedule_end(). Group by (tier, request) so
# each request's blocks are submitted as one batched job per tier.
tier_pending = self._pending_load_submissions.setdefault(tier, {})
ctx_id = req_context.req_id
if ctx_id not in tier_pending:
@@ -320,8 +320,8 @@ class TieringOffloadingManager(OffloadingManager):
def _flush_pending_promotions(self) -> None:
"""Submit one batched submit_load() per (tier, request).
Called from take_events() at the end of each engine step, flushing
all promotion requests deferred during lookup().
Called from on_schedule_end() at the end of each scheduler step,
flushing all promotion requests deferred during lookup().
"""
if not self._pending_load_submissions:
return
@@ -564,32 +564,26 @@ class TieringOffloadingManager(OffloadingManager):
self._request_level_tiers.pop(req_context.req_id, None)
@override
def take_events(self) -> Iterable[OffloadingEvent]:
"""
End-of-step hook: flush deferred work, yield events, reset per-step state.
def on_schedule_end(self) -> None:
"""End-of-schedule hook: process finished jobs, flush deferred
promotions, and reset the per-step gate.
Called once per engine step from Scheduler.update_from_output() →
connector.take_events(). Ensures _maybe_process_finished_jobs() has run
at least once this step, flushes pending promotions, yields collected
events, and resets the per-step flag.
Called once per scheduler step from
OffloadingConnectorScheduler.build_connector_meta().
"""
self._maybe_process_finished_jobs()
self._processed_jobs_this_step = False
self._flush_pending_promotions()
for tier in self.secondary_tiers:
tier.on_schedule_end()
@override
def take_events(self) -> Iterable[OffloadingEvent]:
"""Yield offloading events collected since the last call.
Yields:
New OffloadingEvents collected since the last call.
"""
# TODO: Move _flush_pending_promotions() to a dedicated end_of_batch()
# hook once one exists. For now, take_events() serves as the flush
# point under the assumption that it is called at the end of each
# engine step (Scheduler.update_from_output() → connector.take_events()).
# When the dedicated hook is added, update tests that rely on
# take_events() to signal end of step.
self._maybe_process_finished_jobs()
self._flush_pending_promotions()
# Reset the per-step gate so next step's first call does real work.
self._processed_jobs_this_step = False
if self.events is not None:
yield from self.events
self.events.clear()