[SharedOffloadRegion] Align blocks to page-size (#43689)

Signed-off-by: varun sundar rabindranath <vsundarr@redhat.com>
Co-authored-by: varun sundar rabindranath <vsundarr@redhat.com>
This commit is contained in:
Varun Sundar Rabindranath
2026-06-03 07:25:57 -04:00
committed by GitHub
parent 823d271c0d
commit 3d76f395e3
6 changed files with 82 additions and 58 deletions
+6 -3
View File
@@ -8,6 +8,7 @@ import pytest
import torch
from vllm.platforms import current_platform
from vllm.utils.math_utils import round_up
from vllm.utils.torch_utils import set_random_seed
from vllm.v1.kv_offload.base import (
CanonicalKVCacheRef,
@@ -90,13 +91,15 @@ def test_transfer(
mmap_region: SharedOffloadRegion | None = None
if use_shared_memory:
cpu_page_size = gpu_page_size_bytes * num_tensors * block_size_factor
cpu_page_size = round_up(
gpu_page_size_bytes * num_tensors * block_size_factor,
SharedOffloadRegion.BLOCK_SIZE_ALIGNMENT,
)
mmap_region = SharedOffloadRegion(
instance_id=str(uuid.uuid4()),
total_size_bytes=num_cpu_blocks * cpu_page_size,
num_blocks=num_cpu_blocks,
rank=0,
num_workers=1,
kv_bytes_per_block=cpu_page_size,
cpu_page_size=cpu_page_size,
)
@@ -40,14 +40,12 @@ def _make_region(
num_workers: int = 1,
rank: int = 0,
) -> SharedOffloadRegion:
total_size_bytes = num_blocks * num_workers * cpu_page_size
assert total_size_bytes % PAGE_SIZE == 0
assert cpu_page_size % PAGE_SIZE == 0
return SharedOffloadRegion(
instance_id=instance_id,
total_size_bytes=total_size_bytes,
num_blocks=num_blocks,
rank=rank,
num_workers=num_workers,
kv_bytes_per_block=num_workers * cpu_page_size,
cpu_page_size=cpu_page_size,
)
@@ -77,14 +75,12 @@ def _multi_region(
cpu_page_size: int = PAGE_SIZE,
):
"""Context manager: create one SharedOffloadRegion per rank, clean up on exit."""
total = num_blocks * num_workers * cpu_page_size
regions = [
SharedOffloadRegion(
instance_id=instance_id,
total_size_bytes=total,
num_blocks=num_blocks,
rank=rank,
num_workers=num_workers,
kv_bytes_per_block=num_workers * cpu_page_size,
cpu_page_size=cpu_page_size,
)
for rank in range(num_workers)
@@ -104,7 +100,6 @@ def _race_construct(
cpu_page_size: int = PAGE_SIZE,
) -> tuple[list[SharedOffloadRegion], list[Exception]]:
"""Spawn num_workers threads that all race to construct SharedOffloadRegion."""
total = num_blocks * num_workers * cpu_page_size
regions: list[SharedOffloadRegion | None] = [None] * num_workers
errors: list[Exception] = []
barrier = threading.Barrier(num_workers)
@@ -114,10 +109,9 @@ def _race_construct(
try:
regions[rank] = SharedOffloadRegion(
instance_id=instance_id,
total_size_bytes=total,
num_blocks=num_blocks,
rank=rank,
num_workers=num_workers,
kv_bytes_per_block=num_workers * cpu_page_size,
cpu_page_size=cpu_page_size,
)
except Exception as e:
@@ -134,7 +128,6 @@ def _race_construct(
def _mp_race_construct_and_write(
instance_id: str,
total_bytes: int,
num_blocks: int,
rank: int,
num_workers: int,
@@ -149,10 +142,9 @@ def _mp_race_construct_and_write(
try:
region = SharedOffloadRegion(
instance_id=instance_id,
total_size_bytes=total_bytes,
num_blocks=num_blocks,
rank=rank,
num_workers=num_workers,
kv_bytes_per_block=num_workers * cpu_page_size,
cpu_page_size=cpu_page_size,
)
t = region.create_next_view(cpu_page_size)
@@ -309,7 +301,6 @@ def test_create_next_view_multiprocess_slots(iid):
the parent verifies each slot lands at the correct interleaved offset."""
num_workers = 2
num_blocks = 4
total_bytes = num_blocks * num_workers * PAGE_SIZE
ctx = get_mp_context()
done_queue = ctx.Queue()
@@ -318,10 +309,9 @@ def test_create_next_view_multiprocess_slots(iid):
# Parent is rank 0 (creator); child is rank 1 (joiner).
region = SharedOffloadRegion(
instance_id=iid,
total_size_bytes=total_bytes,
num_blocks=num_blocks,
rank=0,
num_workers=num_workers,
kv_bytes_per_block=num_workers * PAGE_SIZE,
cpu_page_size=PAGE_SIZE,
)
try:
@@ -329,7 +319,6 @@ def test_create_next_view_multiprocess_slots(iid):
target=_mp_race_construct_and_write,
args=(
iid,
total_bytes,
num_blocks,
1,
num_workers,
@@ -464,7 +453,6 @@ def test_multiprocess_race_construct_and_write(iid):
fill_value = rank+1 into their slot; parent verifies interleaved layout."""
num_workers = 4
num_blocks = 3
total_bytes = num_blocks * num_workers * PAGE_SIZE
ctx = get_mp_context()
done_queue = ctx.Queue()
@@ -475,7 +463,6 @@ def test_multiprocess_race_construct_and_write(iid):
target=_mp_race_construct_and_write,
args=(
iid,
total_bytes,
num_blocks,
rank,
num_workers,
+32 -5
View File
@@ -8,6 +8,7 @@ The tier manager writes KV cache blocks to disk and reads them back, verifying
data integrity throughout the process.
"""
import mmap
import os
import time
from unittest.mock import MagicMock
@@ -26,8 +27,8 @@ from vllm.v1.kv_offload.tiering.fs.manager import (
# Helpers
# ---------------------------------------------------------------------------
_BLOCK_ELEMENTS = 512 * 1024 # 2 MB per block (float32 × 512K = 2MB)
_DTYPE = torch.float32
_BLOCK_ELEMENTS = 128 * mmap.PAGESIZE # 2MB per block for pagesize 4096.
_DTYPE: torch.dtype = torch.float32
_CTX = ReqContext(req_id="test")
_MOCK_VLLM_CONFIG = MagicMock()
@@ -90,6 +91,32 @@ def drain(tier: FileSystemTierManager, max_rounds: int = 40) -> list:
return results
def _page_aligned_zero_tensor(
num_blocks: int, block_elements: int, dtype: torch.dtype = _DTYPE
) -> torch.Tensor:
page_size = mmap.PAGESIZE
dtype_num_bytes = torch.tensor([], dtype=dtype).element_size()
num_bytes = num_blocks * block_elements * dtype_num_bytes
num_bytes_aligned = num_bytes + page_size
t = torch.zeros(num_bytes_aligned, dtype=torch.uint8)
ptr = t.data_ptr()
alignment_offset = ptr % page_size
# Move tensor to next page regardless.
shift = page_size - alignment_offset
t = t[shift : shift + num_bytes]
return t.view(dtype).view(num_blocks, block_elements)
def _page_aligned_rand_tensor(
num_blocks: int, block_elements: int, dtype: torch.dtype = _DTYPE
) -> torch.Tensor:
rand_tensor = _page_aligned_zero_tensor(num_blocks, block_elements)
rand_tensor[:] = torch.rand(num_blocks, block_elements, dtype=dtype)
return rand_tensor
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@@ -97,7 +124,7 @@ def drain(tier: FileSystemTierManager, max_rounds: int = 40) -> list:
@pytest.fixture
def fs_tier(tmp_path):
tensor = torch.zeros((4, _BLOCK_ELEMENTS), dtype=_DTYPE)
tensor = _page_aligned_zero_tensor(4, _BLOCK_ELEMENTS)
mock_view = memoryview(tensor.numpy())
tier = FileSystemTierManager(
offloading_spec=_MOCK_OFFLOADING_SPEC,
@@ -155,7 +182,7 @@ def test_store_then_load_roundtrip(fs_tier):
def test_invalid_path_raises_at_construction():
"""Construction must fail immediately when the config file cannot be written."""
tensor = torch.zeros((32, _BLOCK_ELEMENTS), dtype=_DTYPE)
tensor = _page_aligned_zero_tensor(32, _BLOCK_ELEMENTS)
mock_view = memoryview(tensor.numpy())
with pytest.raises(OSError):
@@ -228,7 +255,7 @@ def test_store_load_data_integrity(fs_tier):
"""Data written by store must be exactly recovered by load."""
tier, tensor = fs_tier
# Populate tensor with random data
tensor[:] = torch.rand((4, _BLOCK_ELEMENTS), dtype=_DTYPE)
tensor[:] = _page_aligned_rand_tensor(4, _BLOCK_ELEMENTS)
# Store first 2 blocks
num_store = 2