[SharedOffloadRegion] Align blocks to page-size (#43689)

Signed-off-by: varun sundar rabindranath <vsundarr@redhat.com>
Co-authored-by: varun sundar rabindranath <vsundarr@redhat.com>
This commit is contained in:
Varun Sundar Rabindranath
2026-06-03 07:25:57 -04:00
committed by GitHub
parent 823d271c0d
commit 3d76f395e3
6 changed files with 82 additions and 58 deletions
+6 -3
View File
@@ -8,6 +8,7 @@ import pytest
import torch
from vllm.platforms import current_platform
from vllm.utils.math_utils import round_up
from vllm.utils.torch_utils import set_random_seed
from vllm.v1.kv_offload.base import (
CanonicalKVCacheRef,
@@ -90,13 +91,15 @@ def test_transfer(
mmap_region: SharedOffloadRegion | None = None
if use_shared_memory:
cpu_page_size = gpu_page_size_bytes * num_tensors * block_size_factor
cpu_page_size = round_up(
gpu_page_size_bytes * num_tensors * block_size_factor,
SharedOffloadRegion.BLOCK_SIZE_ALIGNMENT,
)
mmap_region = SharedOffloadRegion(
instance_id=str(uuid.uuid4()),
total_size_bytes=num_cpu_blocks * cpu_page_size,
num_blocks=num_cpu_blocks,
rank=0,
num_workers=1,
kv_bytes_per_block=cpu_page_size,
cpu_page_size=cpu_page_size,
)
@@ -40,14 +40,12 @@ def _make_region(
num_workers: int = 1,
rank: int = 0,
) -> SharedOffloadRegion:
total_size_bytes = num_blocks * num_workers * cpu_page_size
assert total_size_bytes % PAGE_SIZE == 0
assert cpu_page_size % PAGE_SIZE == 0
return SharedOffloadRegion(
instance_id=instance_id,
total_size_bytes=total_size_bytes,
num_blocks=num_blocks,
rank=rank,
num_workers=num_workers,
kv_bytes_per_block=num_workers * cpu_page_size,
cpu_page_size=cpu_page_size,
)
@@ -77,14 +75,12 @@ def _multi_region(
cpu_page_size: int = PAGE_SIZE,
):
"""Context manager: create one SharedOffloadRegion per rank, clean up on exit."""
total = num_blocks * num_workers * cpu_page_size
regions = [
SharedOffloadRegion(
instance_id=instance_id,
total_size_bytes=total,
num_blocks=num_blocks,
rank=rank,
num_workers=num_workers,
kv_bytes_per_block=num_workers * cpu_page_size,
cpu_page_size=cpu_page_size,
)
for rank in range(num_workers)
@@ -104,7 +100,6 @@ def _race_construct(
cpu_page_size: int = PAGE_SIZE,
) -> tuple[list[SharedOffloadRegion], list[Exception]]:
"""Spawn num_workers threads that all race to construct SharedOffloadRegion."""
total = num_blocks * num_workers * cpu_page_size
regions: list[SharedOffloadRegion | None] = [None] * num_workers
errors: list[Exception] = []
barrier = threading.Barrier(num_workers)
@@ -114,10 +109,9 @@ def _race_construct(
try:
regions[rank] = SharedOffloadRegion(
instance_id=instance_id,
total_size_bytes=total,
num_blocks=num_blocks,
rank=rank,
num_workers=num_workers,
kv_bytes_per_block=num_workers * cpu_page_size,
cpu_page_size=cpu_page_size,
)
except Exception as e:
@@ -134,7 +128,6 @@ def _race_construct(
def _mp_race_construct_and_write(
instance_id: str,
total_bytes: int,
num_blocks: int,
rank: int,
num_workers: int,
@@ -149,10 +142,9 @@ def _mp_race_construct_and_write(
try:
region = SharedOffloadRegion(
instance_id=instance_id,
total_size_bytes=total_bytes,
num_blocks=num_blocks,
rank=rank,
num_workers=num_workers,
kv_bytes_per_block=num_workers * cpu_page_size,
cpu_page_size=cpu_page_size,
)
t = region.create_next_view(cpu_page_size)
@@ -309,7 +301,6 @@ def test_create_next_view_multiprocess_slots(iid):
the parent verifies each slot lands at the correct interleaved offset."""
num_workers = 2
num_blocks = 4
total_bytes = num_blocks * num_workers * PAGE_SIZE
ctx = get_mp_context()
done_queue = ctx.Queue()
@@ -318,10 +309,9 @@ def test_create_next_view_multiprocess_slots(iid):
# Parent is rank 0 (creator); child is rank 1 (joiner).
region = SharedOffloadRegion(
instance_id=iid,
total_size_bytes=total_bytes,
num_blocks=num_blocks,
rank=0,
num_workers=num_workers,
kv_bytes_per_block=num_workers * PAGE_SIZE,
cpu_page_size=PAGE_SIZE,
)
try:
@@ -329,7 +319,6 @@ def test_create_next_view_multiprocess_slots(iid):
target=_mp_race_construct_and_write,
args=(
iid,
total_bytes,
num_blocks,
1,
num_workers,
@@ -464,7 +453,6 @@ def test_multiprocess_race_construct_and_write(iid):
fill_value = rank+1 into their slot; parent verifies interleaved layout."""
num_workers = 4
num_blocks = 3
total_bytes = num_blocks * num_workers * PAGE_SIZE
ctx = get_mp_context()
done_queue = ctx.Queue()
@@ -475,7 +463,6 @@ def test_multiprocess_race_construct_and_write(iid):
target=_mp_race_construct_and_write,
args=(
iid,
total_bytes,
num_blocks,
rank,
num_workers,
+32 -5
View File
@@ -8,6 +8,7 @@ The tier manager writes KV cache blocks to disk and reads them back, verifying
data integrity throughout the process.
"""
import mmap
import os
import time
from unittest.mock import MagicMock
@@ -26,8 +27,8 @@ from vllm.v1.kv_offload.tiering.fs.manager import (
# Helpers
# ---------------------------------------------------------------------------
_BLOCK_ELEMENTS = 512 * 1024 # 2 MB per block (float32 × 512K = 2MB)
_DTYPE = torch.float32
_BLOCK_ELEMENTS = 128 * mmap.PAGESIZE # 2MB per block for pagesize 4096.
_DTYPE: torch.dtype = torch.float32
_CTX = ReqContext(req_id="test")
_MOCK_VLLM_CONFIG = MagicMock()
@@ -90,6 +91,32 @@ def drain(tier: FileSystemTierManager, max_rounds: int = 40) -> list:
return results
def _page_aligned_zero_tensor(
    num_blocks: int, block_elements: int, dtype: torch.dtype = _DTYPE
) -> torch.Tensor:
    """Return a zero-filled ``(num_blocks, block_elements)`` tensor of
    ``dtype`` whose data pointer is a multiple of ``mmap.PAGESIZE``.

    A ``uint8`` backing buffer one page larger than the payload is
    allocated, and a page-aligned window of exactly the payload size is
    sliced out of it and reinterpreted as ``dtype``.
    """
    element_bytes = torch.tensor([], dtype=dtype).element_size()
    payload_bytes = num_blocks * block_elements * element_bytes
    backing = torch.zeros(payload_bytes + mmap.PAGESIZE, dtype=torch.uint8)
    # Always advance to the next page boundary: a buffer that is already
    # aligned is shifted by one full page, which the extra page of slack
    # in the allocation makes room for.
    start = mmap.PAGESIZE - (backing.data_ptr() % mmap.PAGESIZE)
    window = backing[start : start + payload_bytes]
    return window.view(dtype).view(num_blocks, block_elements)
def _page_aligned_rand_tensor(
    num_blocks: int, block_elements: int, dtype: torch.dtype = _DTYPE
) -> torch.Tensor:
    """Return a ``(num_blocks, block_elements)`` tensor of ``dtype`` filled
    with values from ``torch.rand``, backed by the buffer produced by
    ``_page_aligned_zero_tensor``.

    Args:
        num_blocks: Number of rows (blocks) in the tensor.
        block_elements: Number of elements per block.
        dtype: Element type of both the buffer and the random values.
    """
    # Forward ``dtype`` to the allocator so the returned buffer actually has
    # the requested element type; otherwise a non-default dtype would be
    # silently cast into a ``_DTYPE`` buffer on assignment.
    rand_tensor = _page_aligned_zero_tensor(num_blocks, block_elements, dtype=dtype)
    rand_tensor[:] = torch.rand(num_blocks, block_elements, dtype=dtype)
    return rand_tensor
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@@ -97,7 +124,7 @@ def drain(tier: FileSystemTierManager, max_rounds: int = 40) -> list:
@pytest.fixture
def fs_tier(tmp_path):
tensor = torch.zeros((4, _BLOCK_ELEMENTS), dtype=_DTYPE)
tensor = _page_aligned_zero_tensor(4, _BLOCK_ELEMENTS)
mock_view = memoryview(tensor.numpy())
tier = FileSystemTierManager(
offloading_spec=_MOCK_OFFLOADING_SPEC,
@@ -155,7 +182,7 @@ def test_store_then_load_roundtrip(fs_tier):
def test_invalid_path_raises_at_construction():
"""Construction must fail immediately when the config file cannot be written."""
tensor = torch.zeros((32, _BLOCK_ELEMENTS), dtype=_DTYPE)
tensor = _page_aligned_zero_tensor(32, _BLOCK_ELEMENTS)
mock_view = memoryview(tensor.numpy())
with pytest.raises(OSError):
@@ -228,7 +255,7 @@ def test_store_load_data_integrity(fs_tier):
"""Data written by store must be exactly recovered by load."""
tier, tensor = fs_tier
# Populate tensor with random data
tensor[:] = torch.rand((4, _BLOCK_ELEMENTS), dtype=_DTYPE)
tensor[:] = _page_aligned_rand_tensor(4, _BLOCK_ELEMENTS)
# Store first 2 blocks
num_store = 2
@@ -35,24 +35,26 @@ class SharedOffloadRegion:
File path: /dev/shm/vllm_offload_{instance_id}.mmap
"""
BLOCK_SIZE_ALIGNMENT: int = mmap.PAGESIZE
def __init__(
self,
instance_id: str,
total_size_bytes: int,
num_blocks: int,
rank: int | None,
num_workers: int,
kv_bytes_per_block: int,
cpu_page_size: int,
) -> None:
self.page_size = mmap.PAGESIZE
assert kv_bytes_per_block % self.page_size == 0
self.num_blocks = num_blocks
self._row_stride = kv_bytes_per_block
self.total_size_bytes = self.num_blocks * self._row_stride
self.total_size_bytes = total_size_bytes
self.mmap_path = f"/dev/shm/vllm_offload_{instance_id}.mmap"
self._creator = False # set True only if this worker creates the file
self.num_blocks = num_blocks
self.rank = rank
# interleaved-layout stride: one row = all workers' data for one block
self._row_stride = cpu_page_size * num_workers
if rank is not None:
# byte offset to this worker's first slot within each block row
self._worker_offset = rank * cpu_page_size
+26 -15
View File
@@ -6,6 +6,7 @@ from typing_extensions import override
from vllm.config import VllmConfig
from vllm.platforms import current_platform
from vllm.utils.math_utils import round_up
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.kv_offload.base import (
CanonicalKVCaches,
@@ -21,6 +22,8 @@ from vllm.v1.kv_offload.worker.worker import OffloadingHandler
class CPUOffloadingSpec(OffloadingSpec):
BLOCK_SIZE_ALIGNMENT = 1
def __init__(self, vllm_config: VllmConfig, kv_cache_config: KVCacheConfig):
super().__init__(vllm_config, kv_cache_config)
@@ -30,26 +33,34 @@ class CPUOffloadingSpec(OffloadingSpec):
"cpu_bytes_to_use must be specified in kv_connector_extra_config"
)
# calculate kv_bytes_per_offloaded_block
world_size = vllm_config.parallel_config.world_size
self.num_blocks = 0
self.kv_bytes_per_offloaded_block = 0
self.cpu_page_size_per_worker = 0
assert kv_cache_config is not None
if kv_cache_config.num_blocks > 0:
if kv_cache_config.num_blocks > 0 and world_size > 0:
total_gpu_kv_bytes = sum(t.size for t in kv_cache_config.kv_cache_tensors)
kv_bytes_per_block = (
total_gpu_kv_bytes // kv_cache_config.num_blocks
) * vllm_config.parallel_config.world_size
else:
kv_bytes_per_block = 0
) * world_size
kv_bytes_per_offloaded_block = kv_bytes_per_block * self.block_size_factor
kv_bytes_per_offloaded_block = kv_bytes_per_block * self.block_size_factor
self.num_blocks = (
int(cpu_bytes_to_use) // kv_bytes_per_offloaded_block
if kv_bytes_per_offloaded_block > 0
else 0
)
world_size = vllm_config.parallel_config.world_size
self.cpu_page_size_per_worker: int = (
kv_bytes_per_offloaded_block // world_size if world_size > 0 else 0
)
# calculate cpu_page_size_per_worker
self.cpu_page_size_per_worker = kv_bytes_per_offloaded_block // world_size
# calculate num_blocks
aligned_kv_bytes_per_offloaded_block = round_up(
kv_bytes_per_offloaded_block, self.BLOCK_SIZE_ALIGNMENT
)
self.num_blocks = (
int(cpu_bytes_to_use) // aligned_kv_bytes_per_offloaded_block
)
# Expose aligned_kv_bytes_per_offloaded_block as
# kv_bytes_per_offloaded_block. Note that this may include
# trailing padding, i.e. each offloaded block has the form:
# |--- W0-B0---|---- W1-B0---| ... |---- Wn-B0---| *** maybe-pad *** |
self.kv_bytes_per_offloaded_block = aligned_kv_bytes_per_offloaded_block
# scheduler-side
self._manager: OffloadingManager | None = None
+4 -10
View File
@@ -63,6 +63,8 @@ class TieringOffloadingSpec(CPUOffloadingSpec):
memory and must transfer data through the primary tier.
"""
BLOCK_SIZE_ALIGNMENT = SharedOffloadRegion.BLOCK_SIZE_ALIGNMENT
def __init__(self, vllm_config: VllmConfig, kv_cache_config: KVCacheConfig):
super().__init__(vllm_config, kv_cache_config)
# Redeclare for mypy: parent sets this but `--follow-imports skip` hides it
@@ -96,15 +98,11 @@ class TieringOffloadingSpec(CPUOffloadingSpec):
# Create scheduler-side SharedOffloadRegion (rank=None) so the
# primary tier can eagerly create a memoryview over _base.
world_size = self.vllm_config.parallel_config.world_size
scheduler_mmap = SharedOffloadRegion(
instance_id=self.vllm_config.instance_id,
total_size_bytes=self.cpu_page_size_per_worker
* world_size
* self.num_blocks,
num_blocks=self.num_blocks,
rank=None,
num_workers=world_size,
kv_bytes_per_block=self.kv_bytes_per_offloaded_block,
cpu_page_size=self.cpu_page_size_per_worker,
)
self._scheduler_mmap = scheduler_mmap
@@ -165,16 +163,12 @@ class TieringOffloadingSpec(CPUOffloadingSpec):
@override
def create_handlers(self, kv_caches: CanonicalKVCaches) -> CpuGpuOffloadingHandlers:
world_size = self.vllm_config.parallel_config.world_size
rank = torch.accelerator.current_device_index()
worker_mmap = SharedOffloadRegion(
instance_id=self.vllm_config.instance_id,
total_size_bytes=self.cpu_page_size_per_worker
* world_size
* self.num_blocks,
num_blocks=self.num_blocks,
rank=rank,
num_workers=world_size,
kv_bytes_per_block=self.kv_bytes_per_offloaded_block,
cpu_page_size=self.cpu_page_size_per_worker,
)
return CpuGpuOffloadingHandlers(