Add objectstore as a secondary tier to multi-tier kv cache offloading (#41968)

Signed-off-by: Effi Ofer <effi.ofer@gmail.com>
This commit is contained in:
Effi Ofer
2026-06-05 18:05:41 +03:00
committed by GitHub
parent 7f003a1285
commit 6a894574bf
6 changed files with 659 additions and 2 deletions
@@ -0,0 +1,365 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Mock-based unit tests for ObjectStoreSecondaryTierManager.
These tests replace the NIXL backend with an in-memory mock so they run
without S3 credentials or a live object store. They verify the manager's
state machine: job submission, transfer completion polling, and lookup.
"""
import uuid
from collections.abc import Callable
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import numpy as np
import torch
from vllm.v1.kv_offload.base import OffloadKey, ReqContext, make_offload_key
from vllm.v1.kv_offload.tiering.base import JobMetadata, JobResult
from vllm.v1.kv_offload.tiering.obj.manager import ObjectStoreSecondaryTierManager
# ---------------------------------------------------------------------------
# Shared stubs
# ---------------------------------------------------------------------------
def _make_vllm_config():
return SimpleNamespace(
model_config=SimpleNamespace(model="test/model"),
cache_config=SimpleNamespace(block_size=16, cache_dtype="float16"),
parallel_config=SimpleNamespace(
tensor_parallel_size=1,
pipeline_parallel_size=1,
prefill_context_parallel_size=1,
decode_context_parallel_size=1,
rank=0,
),
)
_OFFLOADING_SPEC = SimpleNamespace(
vllm_config=_make_vllm_config(),
kv_cache_config=SimpleNamespace(kv_cache_groups=[]),
)
_STORE_CONFIG = {
"bucket": "mock-bucket",
"endpoint_override": "mock:9000",
"access_key": "mock-access",
"secret_key": "mock-secret",
}
_BLOCK_ELEMENTS = 256
_DTYPE = torch.float32
_RUN_PREFIX = f"test/{uuid.uuid4().hex[:8]}"
_CTX = ReqContext(req_id="test-req")
def key(n: int) -> OffloadKey:
return make_offload_key(n.to_bytes(8, "big"), 0)
def make_job(
job_id: int,
keys: list[OffloadKey],
block_ids: list[int] | None = None,
) -> JobMetadata:
if block_ids is None:
block_ids = list(range(len(keys)))
return JobMetadata(
job_id=job_id,
keys=keys,
block_ids=np.array(block_ids, dtype=np.int64),
is_promotion=False,
req_context=_CTX,
)
# ---------------------------------------------------------------------------
# Mock NIXL agent
# ---------------------------------------------------------------------------
class MockNixlAgent:
"""In-memory NIXL agent. Tracks stored object keys and simulates async
transfers: transfer() returns PROC, check_xfer_state() returns DONE and
commits the write to the in-memory key set.
The four methods overridden by tests (register_memory, make_prepped_xfer,
check_xfer_state, query_memory) are stored as Callable instance attributes
so mypy allows reassignment in tests.
"""
# Callable attributes — tests may reassign these on instances.
register_memory: Callable
make_prepped_xfer: Callable
check_xfer_state: Callable
query_memory: Callable
def __init__(self):
self._stored_obj_keys: set[str] = set()
# handle_id -> (op, [obj_keys])
self._pending: dict[int, tuple[str, list[str]]] = {}
self._handle_counter = 0
self._last_obj_keys: list[str] = []
# Bind default implementations as instance attributes.
self.register_memory = self._register_memory
self.make_prepped_xfer = self._make_prepped_xfer
self.check_xfer_state = self._check_xfer_state
self.query_memory = self._query_memory
def create_backend(self, backend_type, params):
pass
def _register_memory(self, descs, mem_type=None, backends=None):
mock = MagicMock()
mock.trim.return_value = MagicMock()
# Capture obj_keys from OBJ 4-tuples: (addr, len, dev_id, obj_key)
if mem_type == "OBJ" and descs:
self._last_obj_keys = [d[3] for d in descs if d[3]]
return mock
def deregister_memory(self, desc):
pass
def prep_xfer_dlist(self, agent_name, descs, mem_type=None, backends=None):
return MagicMock()
def _make_prepped_xfer(
self,
op,
local_handle,
local_indices,
remote_handle,
remote_indices,
notif_msg=b"",
backends=None,
skip_desc_merge=False,
):
handle = MagicMock()
handle._id = self._handle_counter
self._pending[self._handle_counter] = (op, list(self._last_obj_keys))
self._handle_counter += 1
return handle
def transfer(self, handle):
return "PROC"
def _check_xfer_state(self, handle):
entry = self._pending.pop(handle._id, None)
if entry:
op, obj_keys = entry
if op == "WRITE":
self._stored_obj_keys.update(obj_keys)
return "DONE"
def release_xfer_handle(self, handle):
pass
def release_dlist_handle(self, handle):
pass
def _query_memory(self, queries, mem_type, agent_name):
return [object() if q[3] in self._stored_obj_keys else None for q in queries]
# ---------------------------------------------------------------------------
# Fixture
# ---------------------------------------------------------------------------
def _make_tier(
num_blocks: int = 4,
) -> tuple[ObjectStoreSecondaryTierManager, MockNixlAgent]:
"""Create a tier backed by a fresh MockNixlAgent."""
mock_agent = MockNixlAgent()
tensor = torch.zeros((num_blocks, _BLOCK_ELEMENTS), dtype=_DTYPE)
view = memoryview(tensor.numpy())
with (
patch("vllm.v1.kv_offload.tiering.obj.manager.nixl_agent_config"),
patch(
"vllm.v1.kv_offload.tiering.obj.manager.nixl_agent",
return_value=mock_agent,
),
):
tier = ObjectStoreSecondaryTierManager(
offloading_spec=_OFFLOADING_SPEC,
primary_kv_view=view,
tier_type="obj",
store_config=_STORE_CONFIG,
prefix=_RUN_PREFIX,
)
return tier, mock_agent
def drain(
tier: ObjectStoreSecondaryTierManager, max_rounds: int = 20
) -> list[JobResult]:
"""Poll get_finished_jobs() until all in-flight jobs resolve."""
results: list[JobResult] = []
for _ in range(max_rounds):
results.extend(tier.get_finished_jobs())
if not tier._transfers:
break
return results
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
class TestMockObjTierBasic:
def setup_method(self):
self.tier, self.agent = _make_tier(num_blocks=4)
def test_lookup_empty_tier(self):
assert self.tier.lookup(key(1), _CTX) is False
def test_store_and_lookup(self):
self.tier.submit_store(make_job(1, [key(1)], [0]))
results = drain(self.tier)
assert len(results) == 1
assert results[0].success
assert self.tier.lookup(key(1), _CTX) is True
def test_lookup_unrelated_key_returns_false(self):
self.tier.submit_store(make_job(1, [key(1)], [0]))
drain(self.tier)
assert self.tier.lookup(key(999), _CTX) is False
def test_store_then_load_roundtrip(self):
self.tier.submit_store(make_job(1, [key(1), key(2)], [0, 1]))
results = drain(self.tier)
assert results[0].success
self.tier.submit_load(make_job(2, [key(1), key(2)], [0, 1]))
results = drain(self.tier)
assert len(results) == 1
assert results[0].success
def test_multiple_jobs_tracked_independently(self):
self.tier.submit_store(make_job(1, [key(1)], [0]))
self.tier.submit_store(make_job(2, [key(2)], [1]))
results = drain(self.tier)
assert len(results) == 2
assert all(r.success for r in results)
def test_failed_transfer_reported(self):
self.agent.check_xfer_state = lambda h: "ERR"
self.tier.submit_store(make_job(1, [key(1)], [0]))
results = drain(self.tier)
assert len(results) == 1
assert not results[0].success
def test_pending_transfer_not_returned_until_done(self):
# First poll returns PROC; second poll returns DONE.
call_count = [0]
original = self.agent.check_xfer_state
def delayed(h):
call_count[0] += 1
return "PROC" if call_count[0] == 1 else original(h)
self.agent.check_xfer_state = delayed
self.tier.submit_store(make_job(1, [key(1)], [0]))
assert list(self.tier.get_finished_jobs()) == []
results = list(self.tier.get_finished_jobs())
assert len(results) == 1
assert results[0].success
class TestMockObjTierMultiBlock:
def test_store_multiple_blocks(self):
tier, _ = _make_tier(num_blocks=8)
keys = [key(i) for i in range(8)]
tier.submit_store(make_job(1, keys, list(range(8))))
results = drain(tier)
assert len(results) == 1
assert results[0].success
assert all(tier.lookup(k, _CTX) for k in keys)
def test_partial_block_lookup(self):
tier, _ = _make_tier(num_blocks=4)
tier.submit_store(make_job(1, [key(0), key(1)], [0, 1]))
drain(tier)
assert tier.lookup(key(0), _CTX) is True
assert tier.lookup(key(1), _CTX) is True
assert tier.lookup(key(2), _CTX) is False
class TestMockObjTierFailures:
def test_lookup_exception_returns_false(self):
tier, agent = _make_tier(num_blocks=4)
agent.query_memory = lambda *a, **k: (_ for _ in ()).throw(
RuntimeError("backend error")
)
assert tier.lookup(key(1), _CTX) is False
def test_submit_store_register_memory_failure_reported_in_get_finished(self):
tier, agent = _make_tier(num_blocks=4)
agent.register_memory = lambda *a, **k: None
tier.submit_store(make_job(1, [key(1)], [0]))
results = list(tier.get_finished_jobs())
assert len(results) == 1
assert results[0].job_id == 1
assert not results[0].success
def test_submit_load_register_memory_failure_reported_in_get_finished(self):
tier, agent = _make_tier(num_blocks=4)
agent.register_memory = lambda *a, **k: None
tier.submit_load(make_job(2, [key(1)], [0]))
results = list(tier.get_finished_jobs())
assert len(results) == 1
assert results[0].job_id == 2
assert not results[0].success
def test_submit_store_make_prepped_xfer_failure_reported_in_get_finished(self):
tier, agent = _make_tier(num_blocks=4)
agent.make_prepped_xfer = lambda *a, **k: None
tier.submit_store(make_job(3, [key(1)], [0]))
results = list(tier.get_finished_jobs())
assert len(results) == 1
assert results[0].job_id == 3
assert not results[0].success
def test_failure_and_success_both_returned_by_get_finished(self):
# One job fails at submission, another succeeds in flight.
tier, agent = _make_tier(num_blocks=4)
original_register = agent.register_memory
call_count = [0]
def register_once_fail(*a, **k):
call_count[0] += 1
return None if call_count[0] == 1 else original_register(*a, **k)
agent.register_memory = register_once_fail
tier.submit_store(make_job(1, [key(1)], [0])) # fails immediately
tier.submit_store(make_job(2, [key(2)], [1])) # succeeds
results = drain(tier)
assert len(results) == 2
by_id = {r.job_id: r for r in results}
assert not by_id[1].success
assert by_id[2].success
class TestMockObjTierShutdown:
def test_shutdown_clears_in_flight_transfers(self):
tier, agent = _make_tier(num_blocks=4)
# Keep transfer in flight by never completing it
agent.check_xfer_state = lambda h: "PROC"
tier.submit_store(make_job(1, [key(1)], [0]))
assert len(tier._transfers) == 1
tier.shutdown()
assert len(tier._transfers) == 0
assert tier._dram_prepped_handle is None
assert tier._primary_reg is None
def test_shutdown_idempotent(self):
tier, _ = _make_tier(num_blocks=4)
tier.shutdown()
tier.shutdown() # must not raise
+6
View File
@@ -63,3 +63,9 @@ SecondaryTierFactory.register_tier(
"vllm.v1.kv_offload.tiering.fs.manager",
"FileSystemTierManager",
)
SecondaryTierFactory.register_tier(
"obj",
"vllm.v1.kv_offload.tiering.obj.manager",
"ObjectStoreSecondaryTierManager",
)
+30
View File
@@ -0,0 +1,30 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Connection configuration for the object store secondary tier."""
from dataclasses import dataclass
@dataclass
class ObjStoreConfig:
"""Connection parameters for an object store backend."""
bucket: str
endpoint_override: str
access_key: str
secret_key: str
scheme: str = "http"
ca_bundle: str = ""
def to_nixl_params(self) -> dict[str, str]:
"""Build the NIXL backend params dict."""
params: dict[str, str] = {
"bucket": self.bucket,
"endpoint_override": self.endpoint_override,
"scheme": self.scheme,
"access_key": self.access_key,
"secret_key": self.secret_key,
}
if self.ca_bundle:
params["ca_bundle"] = self.ca_bundle
return params
+256
View File
@@ -0,0 +1,256 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Object store secondary tier implementation."""
import ctypes
from collections.abc import Iterable
from typing import TYPE_CHECKING, NamedTuple
from vllm.distributed.nixl_utils import NixlWrapper as nixl_agent
from vllm.distributed.nixl_utils import nixl_agent_config
from vllm.logger import init_logger
from vllm.v1.kv_offload.base import OffloadKey, ReqContext
from vllm.v1.kv_offload.file_mapper import FileMapper
from vllm.v1.kv_offload.tiering.base import (
JobMetadata,
JobResult,
RequestOffloadingContext,
SecondaryTierManager,
)
from vllm.v1.kv_offload.tiering.obj.config import ObjStoreConfig
if TYPE_CHECKING:
from nixl._api import nixl_prepped_dlist_handle, nixl_xfer_handle
from vllm.v1.kv_offload.base import OffloadingSpec
logger = init_logger(__name__)
NIXL_WRITE = "WRITE"
NIXL_READ = "READ"
NIXL_PROC = "PROC"
NIXL_DONE = "DONE"
# Device ID for CPU DRAM descriptors. DRAM is not a multi-device resource so
# the device ID is always 0.
NIXL_DEV_ID: int = 0
# Fields for NIXL OBJ descriptors: (addr, len, dev_id, obj_key).
# For existence probes addr and len are placeholders — no data is read.
# dev_id=0 is reserved for probes; transfers start from 1.
_PROBE_ADDR: int = 0
_PROBE_LEN: int = 1
_PROBE_DEV_ID: int = 0
class TransferEntry(NamedTuple):
xfer_handle: "nixl_xfer_handle"
files_desc: object
obj_handle: "nixl_prepped_dlist_handle"
class ObjectStoreSecondaryTierManager(SecondaryTierManager):
"""Secondary tier that offloads KV cache blocks to an S3-compatible store.
Handles CPU DRAM <-> S3 transfers only. GPU <-> CPU is managed by the
primary tier. Object keys are formed as ``{prefix}/{hash_shard}/{hash}.bin``.
"""
def __init__(
self,
offloading_spec: "OffloadingSpec",
primary_kv_view: memoryview,
tier_type: str,
store_config: dict,
prefix: str = "",
io_threads: int = 4,
):
super().__init__(offloading_spec, primary_kv_view, tier_type)
agent_config = nixl_agent_config(backends=[])
self._agent = nixl_agent("ObjAgent", agent_config)
obj_config = ObjStoreConfig(**store_config)
params = {**obj_config.to_nixl_params(), "num_threads": str(io_threads)}
self._agent.create_backend("OBJ", params)
self._transfers: dict[int, TransferEntry] = {}
self._failed_jobs: list[JobResult] = []
self._primary_reg = None
self._block_size_bytes: int = 0
root_dir = f"{prefix}/" if prefix else ""
self._file_mapper = FileMapper.from_offloading_spec(root_dir, offloading_spec)
self._next_obj_dev_id: int = 1 # dev_id=0 is reserved for _exists() probes
self._probe_connectivity()
base_addr = ctypes.addressof(ctypes.c_char.from_buffer(primary_kv_view))
assert primary_kv_view.strides is not None
stride = primary_kv_view.strides[0]
self._primary_reg = self._agent.register_memory(
[(base_addr, primary_kv_view.nbytes, NIXL_DEV_ID, "")], "DRAM"
)
self._block_size_bytes = stride
all_blocks = [
(base_addr + i * stride, stride, NIXL_DEV_ID)
for i in range(len(primary_kv_view))
]
# NIXL_INIT_AGENT marks this as the local side; make_prepped_xfer requires
# local_xfer_side tagged with NIXL_INIT_AGENT and remote_xfer_side tagged
# with the peer agent name ("ObjAgent").
self._dram_prepped_handle: nixl_prepped_dlist_handle = (
self._agent.prep_xfer_dlist("NIXL_INIT_AGENT", all_blocks, "DRAM")
)
def _probe_connectivity(self) -> None:
"""Verify object store connectivity at startup via a NIXL lookup probe.
Performs a single exists() check against a synthetic key that will
never exist. A True/False result confirms the bucket is reachable;
an exception indicates misconfigured obj store params and raises RuntimeError.
"""
probe_key = "__nixl_probe__/connectivity_test"
try:
self._exists(probe_key)
logger.info("Object store tier connectivity probe succeeded")
except Exception as e:
raise RuntimeError(
f"Object store tier connectivity probe failed — check bucket, "
f"endpoint_override, access_key, secret_key, and scheme. "
f"Error: {e}"
) from e
def _exists(self, obj_key: str) -> bool:
results = self._agent.query_memory(
[(_PROBE_ADDR, _PROBE_LEN, _PROBE_DEV_ID, obj_key)], "OBJ", "OBJ"
)
return results[0] is not None
def _submit_transfer(
self,
job_id: int,
block_ids: Iterable[int],
obj_keys: Iterable[str],
op: str,
) -> None:
"""Submit an async transfer. op is 'WRITE' (store) or 'READ' (load)."""
block_ids_list = [int(bid) for bid in block_ids]
# The OBJ backend maps devId -> obj_key. All descriptors must have
# unique devIds or later registrations overwrite earlier ones.
nixl_files = [
(0, self._block_size_bytes, dev_id, key)
for dev_id, key in enumerate(obj_keys, self._next_obj_dev_id)
]
self._next_obj_dev_id += len(nixl_files)
files_desc = self._agent.register_memory(nixl_files, "OBJ")
if files_desc is None:
logger.warning("register_memory (OBJ) failed for job %d", job_id)
self._failed_jobs.append(JobResult(job_id=job_id, success=False))
return
obj_handle = self._agent.prep_xfer_dlist("ObjAgent", files_desc.trim())
if not obj_handle:
logger.warning("prep_xfer_dlist (OBJ) failed for job %d", job_id)
self._agent.deregister_memory(files_desc)
self._failed_jobs.append(JobResult(job_id=job_id, success=False))
return
xfer_handle = self._agent.make_prepped_xfer(
op,
self._dram_prepped_handle,
block_ids_list,
obj_handle,
list(range(len(nixl_files))),
)
if not xfer_handle:
logger.warning("make_prepped_xfer failed for job %d", job_id)
self._agent.release_dlist_handle(obj_handle)
self._agent.deregister_memory(files_desc)
self._failed_jobs.append(JobResult(job_id=job_id, success=False))
return
state = self._agent.transfer(xfer_handle)
if state == "ERR":
logger.warning("agent.transfer failed for job %d", job_id)
self._agent.release_dlist_handle(obj_handle)
self._agent.deregister_memory(files_desc)
self._agent.release_xfer_handle(xfer_handle)
self._failed_jobs.append(JobResult(job_id=job_id, success=False))
return
self._transfers[job_id] = TransferEntry(xfer_handle, files_desc, obj_handle)
def lookup(self, key: OffloadKey, req_context: ReqContext) -> bool | None:
try:
return self._exists(self._file_mapper.get_file_name(key))
except Exception as e:
logger.warning("lookup failed for key %s: %s", key, e)
return False
def submit_store(self, job_metadata: JobMetadata) -> None:
obj_keys = (self._file_mapper.get_file_name(k) for k in job_metadata.keys)
self._submit_transfer(
job_metadata.job_id, job_metadata.block_ids, obj_keys, NIXL_WRITE
)
def submit_load(self, job_metadata: JobMetadata) -> None:
obj_keys = (self._file_mapper.get_file_name(k) for k in job_metadata.keys)
self._submit_transfer(
job_metadata.job_id, job_metadata.block_ids, obj_keys, NIXL_READ
)
def on_new_request(self, req_context: ReqContext) -> RequestOffloadingContext:
return RequestOffloadingContext()
def get_finished_jobs(self) -> Iterable[JobResult]:
"""Poll in-flight transfers; return completed (job_id, success) pairs."""
results: list[JobResult] = self._failed_jobs
self._failed_jobs = []
for job_id, entry in list(self._transfers.items()):
try:
state = self._agent.check_xfer_state(entry.xfer_handle)
except Exception as exc:
success = False
logger.warning("check_xfer_state raised for job %d: %s", job_id, exc)
else:
if state == NIXL_PROC:
continue
elif state == NIXL_DONE:
success = True
else:
success = False
logger.warning("transfer failed job=%d state=%s", job_id, state)
del self._transfers[job_id]
self._agent.release_xfer_handle(entry.xfer_handle)
self._agent.release_dlist_handle(entry.obj_handle)
self._agent.deregister_memory(entry.files_desc)
results.append(JobResult(job_id=job_id, success=success))
return results
def shutdown(self) -> None:
for job_id, entry in self._transfers.items():
try:
self._agent.release_xfer_handle(entry.xfer_handle)
except Exception as exc:
logger.warning("release_xfer_handle failed for job %d: %s", job_id, exc)
try:
self._agent.release_dlist_handle(entry.obj_handle)
except Exception as exc:
logger.warning(
"release_dlist_handle failed for job %d: %s", job_id, exc
)
try:
self._agent.deregister_memory(entry.files_desc)
except Exception as exc:
logger.warning("deregister_memory failed for job %d: %s", job_id, exc)
self._transfers.clear()
if self._dram_prepped_handle is not None:
try:
self._agent.release_dlist_handle(self._dram_prepped_handle)
except Exception as exc:
logger.warning("failed to release DRAM prepped handle: %s", exc)
self._dram_prepped_handle = None
if self._primary_reg is not None:
try:
self._agent.deregister_memory(self._primary_reg)
except Exception as exc:
logger.warning("failed to deregister primary buffer: %s", exc)
self._primary_reg = None
+2 -2
View File
@@ -131,8 +131,8 @@ class TieringOffloadingSpec(CPUOffloadingSpec):
)
except Exception as e:
logger.error(
"Failed to create secondary tier from config %s: %s",
tier_config,
"Failed to create secondary tier from config index %i: %s",
i,
e,
)
raise