mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[SharedOffloadRegion] Align blocks to page-size (#43689)
Signed-off-by: varun sundar rabindranath <vsundarr@redhat.com> Co-authored-by: varun sundar rabindranath <vsundarr@redhat.com>
This commit is contained in:
committed by
GitHub
parent
823d271c0d
commit
3d76f395e3
@@ -8,6 +8,7 @@ import pytest
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.math_utils import round_up
|
||||
from vllm.utils.torch_utils import set_random_seed
|
||||
from vllm.v1.kv_offload.base import (
|
||||
CanonicalKVCacheRef,
|
||||
@@ -90,13 +91,15 @@ def test_transfer(
|
||||
|
||||
mmap_region: SharedOffloadRegion | None = None
|
||||
if use_shared_memory:
|
||||
cpu_page_size = gpu_page_size_bytes * num_tensors * block_size_factor
|
||||
cpu_page_size = round_up(
|
||||
gpu_page_size_bytes * num_tensors * block_size_factor,
|
||||
SharedOffloadRegion.BLOCK_SIZE_ALIGNMENT,
|
||||
)
|
||||
mmap_region = SharedOffloadRegion(
|
||||
instance_id=str(uuid.uuid4()),
|
||||
total_size_bytes=num_cpu_blocks * cpu_page_size,
|
||||
num_blocks=num_cpu_blocks,
|
||||
rank=0,
|
||||
num_workers=1,
|
||||
kv_bytes_per_block=cpu_page_size,
|
||||
cpu_page_size=cpu_page_size,
|
||||
)
|
||||
|
||||
|
||||
@@ -40,14 +40,12 @@ def _make_region(
|
||||
num_workers: int = 1,
|
||||
rank: int = 0,
|
||||
) -> SharedOffloadRegion:
|
||||
total_size_bytes = num_blocks * num_workers * cpu_page_size
|
||||
assert total_size_bytes % PAGE_SIZE == 0
|
||||
assert cpu_page_size % PAGE_SIZE == 0
|
||||
return SharedOffloadRegion(
|
||||
instance_id=instance_id,
|
||||
total_size_bytes=total_size_bytes,
|
||||
num_blocks=num_blocks,
|
||||
rank=rank,
|
||||
num_workers=num_workers,
|
||||
kv_bytes_per_block=num_workers * cpu_page_size,
|
||||
cpu_page_size=cpu_page_size,
|
||||
)
|
||||
|
||||
@@ -77,14 +75,12 @@ def _multi_region(
|
||||
cpu_page_size: int = PAGE_SIZE,
|
||||
):
|
||||
"""Context manager: create one SharedOffloadRegion per rank, clean up on exit."""
|
||||
total = num_blocks * num_workers * cpu_page_size
|
||||
regions = [
|
||||
SharedOffloadRegion(
|
||||
instance_id=instance_id,
|
||||
total_size_bytes=total,
|
||||
num_blocks=num_blocks,
|
||||
rank=rank,
|
||||
num_workers=num_workers,
|
||||
kv_bytes_per_block=num_workers * cpu_page_size,
|
||||
cpu_page_size=cpu_page_size,
|
||||
)
|
||||
for rank in range(num_workers)
|
||||
@@ -104,7 +100,6 @@ def _race_construct(
|
||||
cpu_page_size: int = PAGE_SIZE,
|
||||
) -> tuple[list[SharedOffloadRegion], list[Exception]]:
|
||||
"""Spawn num_workers threads that all race to construct SharedOffloadRegion."""
|
||||
total = num_blocks * num_workers * cpu_page_size
|
||||
regions: list[SharedOffloadRegion | None] = [None] * num_workers
|
||||
errors: list[Exception] = []
|
||||
barrier = threading.Barrier(num_workers)
|
||||
@@ -114,10 +109,9 @@ def _race_construct(
|
||||
try:
|
||||
regions[rank] = SharedOffloadRegion(
|
||||
instance_id=instance_id,
|
||||
total_size_bytes=total,
|
||||
num_blocks=num_blocks,
|
||||
rank=rank,
|
||||
num_workers=num_workers,
|
||||
kv_bytes_per_block=num_workers * cpu_page_size,
|
||||
cpu_page_size=cpu_page_size,
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -134,7 +128,6 @@ def _race_construct(
|
||||
|
||||
def _mp_race_construct_and_write(
|
||||
instance_id: str,
|
||||
total_bytes: int,
|
||||
num_blocks: int,
|
||||
rank: int,
|
||||
num_workers: int,
|
||||
@@ -149,10 +142,9 @@ def _mp_race_construct_and_write(
|
||||
try:
|
||||
region = SharedOffloadRegion(
|
||||
instance_id=instance_id,
|
||||
total_size_bytes=total_bytes,
|
||||
num_blocks=num_blocks,
|
||||
rank=rank,
|
||||
num_workers=num_workers,
|
||||
kv_bytes_per_block=num_workers * cpu_page_size,
|
||||
cpu_page_size=cpu_page_size,
|
||||
)
|
||||
t = region.create_next_view(cpu_page_size)
|
||||
@@ -309,7 +301,6 @@ def test_create_next_view_multiprocess_slots(iid):
|
||||
the parent verifies each slot lands at the correct interleaved offset."""
|
||||
num_workers = 2
|
||||
num_blocks = 4
|
||||
total_bytes = num_blocks * num_workers * PAGE_SIZE
|
||||
|
||||
ctx = get_mp_context()
|
||||
done_queue = ctx.Queue()
|
||||
@@ -318,10 +309,9 @@ def test_create_next_view_multiprocess_slots(iid):
|
||||
# Parent is rank 0 (creator); child is rank 1 (joiner).
|
||||
region = SharedOffloadRegion(
|
||||
instance_id=iid,
|
||||
total_size_bytes=total_bytes,
|
||||
num_blocks=num_blocks,
|
||||
rank=0,
|
||||
num_workers=num_workers,
|
||||
kv_bytes_per_block=num_workers * PAGE_SIZE,
|
||||
cpu_page_size=PAGE_SIZE,
|
||||
)
|
||||
try:
|
||||
@@ -329,7 +319,6 @@ def test_create_next_view_multiprocess_slots(iid):
|
||||
target=_mp_race_construct_and_write,
|
||||
args=(
|
||||
iid,
|
||||
total_bytes,
|
||||
num_blocks,
|
||||
1,
|
||||
num_workers,
|
||||
@@ -464,7 +453,6 @@ def test_multiprocess_race_construct_and_write(iid):
|
||||
fill_value = rank+1 into their slot; parent verifies interleaved layout."""
|
||||
num_workers = 4
|
||||
num_blocks = 3
|
||||
total_bytes = num_blocks * num_workers * PAGE_SIZE
|
||||
|
||||
ctx = get_mp_context()
|
||||
done_queue = ctx.Queue()
|
||||
@@ -475,7 +463,6 @@ def test_multiprocess_race_construct_and_write(iid):
|
||||
target=_mp_race_construct_and_write,
|
||||
args=(
|
||||
iid,
|
||||
total_bytes,
|
||||
num_blocks,
|
||||
rank,
|
||||
num_workers,
|
||||
|
||||
@@ -8,6 +8,7 @@ The tier manager writes KV cache blocks to disk and reads them back, verifying
|
||||
data integrity throughout the process.
|
||||
"""
|
||||
|
||||
import mmap
|
||||
import os
|
||||
import time
|
||||
from unittest.mock import MagicMock
|
||||
@@ -26,8 +27,8 @@ from vllm.v1.kv_offload.tiering.fs.manager import (
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_BLOCK_ELEMENTS = 512 * 1024 # 2 MB per block (float32 × 512K = 2MB)
|
||||
_DTYPE = torch.float32
|
||||
_BLOCK_ELEMENTS = 128 * mmap.PAGESIZE # 2MB per block for pagesize 4096.
|
||||
_DTYPE: torch.dtype = torch.float32
|
||||
_CTX = ReqContext(req_id="test")
|
||||
|
||||
_MOCK_VLLM_CONFIG = MagicMock()
|
||||
@@ -90,6 +91,32 @@ def drain(tier: FileSystemTierManager, max_rounds: int = 40) -> list:
|
||||
return results
|
||||
|
||||
|
||||
def _page_aligned_zero_tensor(
|
||||
num_blocks: int, block_elements: int, dtype: torch.dtype = _DTYPE
|
||||
) -> torch.Tensor:
|
||||
page_size = mmap.PAGESIZE
|
||||
dtype_num_bytes = torch.tensor([], dtype=dtype).element_size()
|
||||
|
||||
num_bytes = num_blocks * block_elements * dtype_num_bytes
|
||||
num_bytes_aligned = num_bytes + page_size
|
||||
t = torch.zeros(num_bytes_aligned, dtype=torch.uint8)
|
||||
|
||||
ptr = t.data_ptr()
|
||||
alignment_offset = ptr % page_size
|
||||
# Move tensor to next page regardless.
|
||||
shift = page_size - alignment_offset
|
||||
t = t[shift : shift + num_bytes]
|
||||
return t.view(dtype).view(num_blocks, block_elements)
|
||||
|
||||
|
||||
def _page_aligned_rand_tensor(
|
||||
num_blocks: int, block_elements: int, dtype: torch.dtype = _DTYPE
|
||||
) -> torch.Tensor:
|
||||
rand_tensor = _page_aligned_zero_tensor(num_blocks, block_elements)
|
||||
rand_tensor[:] = torch.rand(num_blocks, block_elements, dtype=dtype)
|
||||
return rand_tensor
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -97,7 +124,7 @@ def drain(tier: FileSystemTierManager, max_rounds: int = 40) -> list:
|
||||
|
||||
@pytest.fixture
|
||||
def fs_tier(tmp_path):
|
||||
tensor = torch.zeros((4, _BLOCK_ELEMENTS), dtype=_DTYPE)
|
||||
tensor = _page_aligned_zero_tensor(4, _BLOCK_ELEMENTS)
|
||||
mock_view = memoryview(tensor.numpy())
|
||||
tier = FileSystemTierManager(
|
||||
offloading_spec=_MOCK_OFFLOADING_SPEC,
|
||||
@@ -155,7 +182,7 @@ def test_store_then_load_roundtrip(fs_tier):
|
||||
|
||||
def test_invalid_path_raises_at_construction():
|
||||
"""Construction must fail immediately when the config file cannot be written."""
|
||||
tensor = torch.zeros((32, _BLOCK_ELEMENTS), dtype=_DTYPE)
|
||||
tensor = _page_aligned_zero_tensor(32, _BLOCK_ELEMENTS)
|
||||
mock_view = memoryview(tensor.numpy())
|
||||
|
||||
with pytest.raises(OSError):
|
||||
@@ -228,7 +255,7 @@ def test_store_load_data_integrity(fs_tier):
|
||||
"""Data written by store must be exactly recovered by load."""
|
||||
tier, tensor = fs_tier
|
||||
# Populate tensor with random data
|
||||
tensor[:] = torch.rand((4, _BLOCK_ELEMENTS), dtype=_DTYPE)
|
||||
tensor[:] = _page_aligned_rand_tensor(4, _BLOCK_ELEMENTS)
|
||||
|
||||
# Store first 2 blocks
|
||||
num_store = 2
|
||||
|
||||
Reference in New Issue
Block a user