mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[KV Connector] Remove compat support for pre-v0.12.0 constructor signatures without KVCacheConfig (#39832)
The v0.12.0 release contained initial support for HMA in KV Connectors. As part of these changes, a KVCacheConfig argument was added to KV connector constructors. Backwards compatibility support for out-of-tree connectors was included in this change, with a very prominent warning. See #25712 and #27887. Since the warning has been around for over 5 months, we can safely remove the support of it. Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
+12
-2
@@ -20,6 +20,7 @@ from vllm.v1.request import Request
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
|
||||
logger = logging.getLogger()
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
@@ -35,8 +36,17 @@ class LoadRecoveryExampleConnectorMetadata(ExampleConnectorMetadata):
|
||||
|
||||
|
||||
class LoadRecoveryExampleConnector(ExampleConnector):
|
||||
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
|
||||
super().__init__(vllm_config=vllm_config, role=role)
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: "VllmConfig",
|
||||
role: KVConnectorRole,
|
||||
kv_cache_config: "KVCacheConfig",
|
||||
):
|
||||
super().__init__(
|
||||
vllm_config=vllm_config,
|
||||
role=role,
|
||||
kv_cache_config=kv_cache_config,
|
||||
)
|
||||
self._async_load = vllm_config.kv_transfer_config.get_from_extra_config(
|
||||
"async_load", False
|
||||
)
|
||||
|
||||
@@ -66,7 +66,7 @@ class DummyKVConnector(KVConnectorBase_V1):
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
role: KVConnectorRole,
|
||||
kv_cache_config: KVCacheConfig | None = None,
|
||||
kv_cache_config: KVCacheConfig,
|
||||
):
|
||||
super().__init__(vllm_config, role, kv_cache_config)
|
||||
# Get the status file path from extra config
|
||||
|
||||
@@ -1,275 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Unit tests for backwards compatibility with external KV connector implementations.
|
||||
|
||||
This test ensures that external connectors (loaded via kv_connector_module_path)
|
||||
implemented with the old signature continue to work:
|
||||
- Old signature: __init__(self, vllm_config, role)
|
||||
- New signature: __init__(self, vllm_config, role, kv_cache_config)
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1 import (
|
||||
KVConnectorBase_V1,
|
||||
KVConnectorRole,
|
||||
)
|
||||
from vllm.v1.attention.backend import AttentionMetadata
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
|
||||
from .utils import create_scheduler, create_vllm_config
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.request import Request
|
||||
|
||||
|
||||
class OldStyleTestConnector(KVConnectorBase_V1):
|
||||
"""
|
||||
Test connector using the old signature with 2 required arguments.
|
||||
This simulates external connectors that haven't been updated yet.
|
||||
"""
|
||||
|
||||
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
|
||||
# Old-style call to super().__init__ with only 2 arguments
|
||||
super().__init__(vllm_config=vllm_config, role=role)
|
||||
|
||||
def get_num_new_matched_tokens(
|
||||
self, request: "Request", num_computed_tokens: int
|
||||
) -> tuple[int | None, bool]:
|
||||
return 0, False
|
||||
|
||||
def update_state_after_alloc(
|
||||
self,
|
||||
request: "Request",
|
||||
blocks: "KVCacheBlocks",
|
||||
num_external_tokens: int,
|
||||
):
|
||||
pass
|
||||
|
||||
def build_connector_meta(self, scheduler_output: SchedulerOutput):
|
||||
return None
|
||||
|
||||
def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
|
||||
pass
|
||||
|
||||
def wait_for_layer_load(self, layer_name: str) -> None:
|
||||
pass
|
||||
|
||||
def save_kv_layer(
|
||||
self,
|
||||
layer_name: str,
|
||||
kv_layer,
|
||||
attn_metadata: AttentionMetadata,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
def wait_for_save(self):
|
||||
pass
|
||||
|
||||
|
||||
class NewStyleTestConnector(KVConnectorBase_V1):
|
||||
"""
|
||||
Test connector using the new signature with 3 required arguments.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: "VllmConfig",
|
||||
role: KVConnectorRole,
|
||||
kv_cache_config: "KVCacheConfig",
|
||||
):
|
||||
# New-style call to super().__init__ with all 3 arguments
|
||||
super().__init__(
|
||||
vllm_config=vllm_config, role=role, kv_cache_config=kv_cache_config
|
||||
)
|
||||
|
||||
def get_num_new_matched_tokens(
|
||||
self, request: "Request", num_computed_tokens: int
|
||||
) -> tuple[int | None, bool]:
|
||||
return 0, False
|
||||
|
||||
def update_state_after_alloc(
|
||||
self,
|
||||
request: "Request",
|
||||
blocks: "KVCacheBlocks",
|
||||
num_external_tokens: int,
|
||||
):
|
||||
pass
|
||||
|
||||
def build_connector_meta(self, scheduler_output: SchedulerOutput):
|
||||
return None
|
||||
|
||||
def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
|
||||
pass
|
||||
|
||||
def wait_for_layer_load(self, layer_name: str) -> None:
|
||||
pass
|
||||
|
||||
def save_kv_layer(
|
||||
self,
|
||||
layer_name: str,
|
||||
kv_layer,
|
||||
attn_metadata: AttentionMetadata,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
def wait_for_save(self):
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.parametrize("role", [KVConnectorRole.SCHEDULER, KVConnectorRole.WORKER])
|
||||
def test_external_old_signature_factory_instantiation(role):
|
||||
"""
|
||||
Test that external connectors with old signature (2 required args) loaded
|
||||
via kv_connector_module_path are correctly instantiated with backwards
|
||||
compatibility support.
|
||||
"""
|
||||
vllm_config = create_vllm_config()
|
||||
vllm_config.kv_transfer_config.kv_connector = "OldStyleTestConnector"
|
||||
vllm_config.kv_transfer_config.kv_connector_module_path = (
|
||||
"tests.v1.kv_connector.unit.test_backwards_compatibility"
|
||||
)
|
||||
|
||||
scheduler = create_scheduler(vllm_config)
|
||||
kv_cache_config = scheduler.kv_cache_config
|
||||
|
||||
connector = KVConnectorFactory.create_connector(vllm_config, role, kv_cache_config)
|
||||
|
||||
assert connector is not None
|
||||
assert isinstance(connector, OldStyleTestConnector)
|
||||
assert connector.role == role
|
||||
assert connector._kv_cache_config is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize("role", [KVConnectorRole.SCHEDULER, KVConnectorRole.WORKER])
|
||||
def test_external_new_signature_factory_instantiation(role):
|
||||
"""
|
||||
Test that external connectors with new signature (3 required args) loaded
|
||||
via kv_connector_module_path are correctly instantiated.
|
||||
"""
|
||||
vllm_config = create_vllm_config()
|
||||
vllm_config.kv_transfer_config.kv_connector = "NewStyleTestConnector"
|
||||
vllm_config.kv_transfer_config.kv_connector_module_path = (
|
||||
"tests.v1.kv_connector.unit.test_backwards_compatibility"
|
||||
)
|
||||
|
||||
scheduler = create_scheduler(vllm_config)
|
||||
kv_cache_config = scheduler.kv_cache_config
|
||||
|
||||
connector = KVConnectorFactory.create_connector(vllm_config, role, kv_cache_config)
|
||||
|
||||
assert connector is not None
|
||||
assert isinstance(connector, NewStyleTestConnector)
|
||||
assert connector.role == role
|
||||
assert connector._kv_cache_config is not None
|
||||
assert connector._kv_cache_config == kv_cache_config
|
||||
|
||||
|
||||
@pytest.mark.parametrize("role", [KVConnectorRole.SCHEDULER, KVConnectorRole.WORKER])
|
||||
def test_old_signature_super_init(role):
|
||||
"""
|
||||
Test that old-style connectors can call super().__init__() without
|
||||
kv_cache_config parameter.
|
||||
"""
|
||||
vllm_config = create_vllm_config()
|
||||
|
||||
connector = OldStyleTestConnector(vllm_config, role)
|
||||
|
||||
assert connector is not None
|
||||
assert connector.role == role
|
||||
assert connector._kv_cache_config is None
|
||||
|
||||
|
||||
def test_old_signature_super_init_with_kwargs():
|
||||
"""
|
||||
Test that old-style connectors can call super().__init__() with keyword
|
||||
arguments in different orders.
|
||||
"""
|
||||
vllm_config = create_vllm_config()
|
||||
|
||||
# Test with vllm_config= and role= kwargs
|
||||
connector1 = OldStyleTestConnector(
|
||||
vllm_config=vllm_config, role=KVConnectorRole.SCHEDULER
|
||||
)
|
||||
assert connector1 is not None
|
||||
assert connector1._kv_cache_config is None
|
||||
|
||||
# Test with role= and vllm_config= in reversed order
|
||||
connector2 = OldStyleTestConnector(
|
||||
role=KVConnectorRole.WORKER, vllm_config=vllm_config
|
||||
)
|
||||
assert connector2 is not None
|
||||
assert connector2._kv_cache_config is None
|
||||
|
||||
|
||||
def test_internal_connector_uses_new_signature():
|
||||
"""
|
||||
Test that internal connectors (registered in factory) always use the new
|
||||
signature and get kv_cache_config.
|
||||
"""
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.example_connector import (
|
||||
ExampleConnector,
|
||||
)
|
||||
|
||||
vllm_config = create_vllm_config()
|
||||
vllm_config.kv_transfer_config.kv_connector = "ExampleConnector"
|
||||
|
||||
scheduler = create_scheduler(vllm_config)
|
||||
kv_cache_config = scheduler.kv_cache_config
|
||||
|
||||
connector = KVConnectorFactory.create_connector(
|
||||
vllm_config, KVConnectorRole.SCHEDULER, kv_cache_config
|
||||
)
|
||||
|
||||
assert connector is not None
|
||||
assert isinstance(connector, ExampleConnector)
|
||||
assert connector._kv_cache_config is not None
|
||||
assert connector._kv_cache_config == kv_cache_config
|
||||
|
||||
|
||||
def test_signature_detection_with_mocking():
|
||||
"""
|
||||
Test that the factory correctly applies compat_sig flag returned from
|
||||
_get_connector_class_with_compat.
|
||||
"""
|
||||
vllm_config = create_vllm_config()
|
||||
scheduler = create_scheduler(vllm_config)
|
||||
kv_cache_config = scheduler.kv_cache_config
|
||||
|
||||
# Mock _get_connector_class_with_compat to return old-style connector
|
||||
with patch.object(
|
||||
KVConnectorFactory,
|
||||
"_get_connector_class_with_compat",
|
||||
return_value=(OldStyleTestConnector, True),
|
||||
):
|
||||
old_connector = KVConnectorFactory.create_connector(
|
||||
vllm_config, KVConnectorRole.SCHEDULER, kv_cache_config
|
||||
)
|
||||
assert old_connector is not None
|
||||
assert isinstance(old_connector, OldStyleTestConnector)
|
||||
assert old_connector._kv_cache_config is None
|
||||
|
||||
# Mock _get_connector_class_with_compat to return new-style connector
|
||||
with patch.object(
|
||||
KVConnectorFactory,
|
||||
"_get_connector_class_with_compat",
|
||||
return_value=(NewStyleTestConnector, False),
|
||||
):
|
||||
new_connector = KVConnectorFactory.create_connector(
|
||||
vllm_config, KVConnectorRole.SCHEDULER, kv_cache_config
|
||||
)
|
||||
assert new_connector is not None
|
||||
assert isinstance(new_connector, NewStyleTestConnector)
|
||||
assert new_connector._kv_cache_config is not None
|
||||
assert new_connector._kv_cache_config == kv_cache_config
|
||||
@@ -58,7 +58,9 @@ class DecodeBenchTestRunner:
|
||||
|
||||
# Create worker-side connector
|
||||
self.worker_connector = DecodeBenchConnector(
|
||||
vllm_config, KVConnectorRole.WORKER
|
||||
vllm_config,
|
||||
KVConnectorRole.WORKER,
|
||||
self.scheduler.kv_cache_config,
|
||||
)
|
||||
|
||||
# Create dummy KV caches for testing
|
||||
|
||||
@@ -10,6 +10,7 @@ from vllm.distributed.kv_transfer.kv_transfer_state import (
|
||||
get_kv_transfer_group,
|
||||
)
|
||||
from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
|
||||
|
||||
# Importing utils registers TestExampleConnector with the factory
|
||||
@@ -38,7 +39,10 @@ def test_kv_connector_mixin_clears_metadata():
|
||||
vllm_config.kv_transfer_config.kv_connector_extra_config["name"] = "unit"
|
||||
|
||||
# Initialize the global connector instance
|
||||
ensure_kv_transfer_initialized(vllm_config)
|
||||
kv_cache_config = KVCacheConfig(
|
||||
num_blocks=0, kv_cache_tensors=[], kv_cache_groups=[]
|
||||
)
|
||||
ensure_kv_transfer_initialized(vllm_config, kv_cache_config)
|
||||
|
||||
try:
|
||||
# Minimal scheduler output with empty metadata; mixin should still
|
||||
|
||||
@@ -26,11 +26,16 @@ from vllm.distributed.kv_transfer.kv_connector.v1.mooncake.mooncake_utils import
|
||||
)
|
||||
from vllm.utils.network_utils import get_open_port
|
||||
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.request import RequestStatus
|
||||
|
||||
from .utils import create_request, create_scheduler, create_vllm_config
|
||||
|
||||
|
||||
def _make_test_kv_cache_config() -> KVCacheConfig:
|
||||
return KVCacheConfig(num_blocks=0, kv_cache_tensors=[], kv_cache_groups=[])
|
||||
|
||||
|
||||
class FakeMooncakeWrapper:
|
||||
"""Mock Mooncake TransferEngine for unit testing environments."""
|
||||
|
||||
@@ -321,7 +326,11 @@ async def test_kv_producer(monkeypatch):
|
||||
)
|
||||
|
||||
with set_current_vllm_config(vllm_config), patch_worker_dependencies():
|
||||
prefill_connector = MooncakeConnector(vllm_config, KVConnectorRole.WORKER)
|
||||
prefill_connector = MooncakeConnector(
|
||||
vllm_config,
|
||||
KVConnectorRole.WORKER,
|
||||
_make_test_kv_cache_config(),
|
||||
)
|
||||
prefill_worker = prefill_connector.connector_worker
|
||||
prefill_worker.kv_caches_base_addr = [0x1000]
|
||||
block_len = 4096
|
||||
@@ -473,7 +482,11 @@ async def test_kv_consumuer(monkeypatch):
|
||||
)
|
||||
|
||||
with set_current_vllm_config(vllm_config), patch_worker_dependencies() as mocks:
|
||||
decode_connector = MooncakeConnector(vllm_config, KVConnectorRole.WORKER)
|
||||
decode_connector = MooncakeConnector(
|
||||
vllm_config,
|
||||
KVConnectorRole.WORKER,
|
||||
_make_test_kv_cache_config(),
|
||||
)
|
||||
decode_worker = decode_connector.connector_worker
|
||||
decode_worker.kv_caches_base_addr = [0x1000]
|
||||
decode_worker.rpc_port = 54321
|
||||
@@ -533,7 +546,11 @@ async def test_worker_get_finished_timeout(monkeypatch):
|
||||
kv_connector="MooncakeConnector", kv_role="kv_producer"
|
||||
)
|
||||
with set_current_vllm_config(vllm_config), patch_worker_dependencies():
|
||||
prefill_connector = MooncakeConnector(vllm_config, KVConnectorRole.WORKER)
|
||||
prefill_connector = MooncakeConnector(
|
||||
vllm_config,
|
||||
KVConnectorRole.WORKER,
|
||||
_make_test_kv_cache_config(),
|
||||
)
|
||||
prefill_worker = prefill_connector.connector_worker
|
||||
|
||||
# Add an expired request (expire_time is in the past).
|
||||
@@ -579,7 +596,11 @@ def test_register_kv_caches():
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.mooncake.mooncake_connector.threading.Thread"
|
||||
) as mock_thread,
|
||||
):
|
||||
connector = MooncakeConnector(vllm_config, KVConnectorRole.WORKER)
|
||||
connector = MooncakeConnector(
|
||||
vllm_config,
|
||||
KVConnectorRole.WORKER,
|
||||
_make_test_kv_cache_config(),
|
||||
)
|
||||
worker = connector.connector_worker
|
||||
mock_thread.return_value.is_alive.return_value = False
|
||||
|
||||
@@ -628,7 +649,11 @@ def test_register_kv_caches_supports_mixed_mla_and_eagle_shapes():
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.mooncake.mooncake_connector.threading.Thread"
|
||||
) as mock_thread,
|
||||
):
|
||||
connector = MooncakeConnector(vllm_config, KVConnectorRole.WORKER)
|
||||
connector = MooncakeConnector(
|
||||
vllm_config,
|
||||
KVConnectorRole.WORKER,
|
||||
_make_test_kv_cache_config(),
|
||||
)
|
||||
worker = connector.connector_worker
|
||||
mock_thread.return_value.is_alive.return_value = False
|
||||
|
||||
@@ -688,7 +713,11 @@ async def test_kv_producer_heterogeneous_tp(monkeypatch, d_tp_size):
|
||||
)
|
||||
|
||||
with set_current_vllm_config(vllm_config), patch_worker_dependencies():
|
||||
prefill_connector = MooncakeConnector(vllm_config, KVConnectorRole.WORKER)
|
||||
prefill_connector = MooncakeConnector(
|
||||
vllm_config,
|
||||
KVConnectorRole.WORKER,
|
||||
_make_test_kv_cache_config(),
|
||||
)
|
||||
prefill_worker = prefill_connector.connector_worker
|
||||
|
||||
# Override TP rank/size to simulate P TP=2
|
||||
|
||||
@@ -221,9 +221,14 @@ async def test_build_transfer_params_multi_group_trimming(monkeypatch):
|
||||
vllm_config = create_vllm_config(
|
||||
kv_connector="MooncakeConnector", kv_role="kv_producer"
|
||||
)
|
||||
kv_cache_config = make_kv_cache_config(
|
||||
block_size=vllm_config.cache_config.block_size, swa_enabled=True
|
||||
)
|
||||
|
||||
with set_current_vllm_config(vllm_config), patch_worker_dependencies():
|
||||
connector = MooncakeConnector(vllm_config, KVConnectorRole.WORKER)
|
||||
connector = MooncakeConnector(
|
||||
vllm_config, KVConnectorRole.WORKER, kv_cache_config
|
||||
)
|
||||
worker = connector.connector_worker
|
||||
|
||||
block_len = 4096
|
||||
@@ -304,9 +309,14 @@ async def test_build_transfer_params_group_count_mismatch(monkeypatch):
|
||||
vllm_config = create_vllm_config(
|
||||
kv_connector="MooncakeConnector", kv_role="kv_producer"
|
||||
)
|
||||
kv_cache_config = make_kv_cache_config(
|
||||
block_size=vllm_config.cache_config.block_size, swa_enabled=True
|
||||
)
|
||||
|
||||
with set_current_vllm_config(vllm_config), patch_worker_dependencies():
|
||||
connector = MooncakeConnector(vllm_config, KVConnectorRole.WORKER)
|
||||
connector = MooncakeConnector(
|
||||
vllm_config, KVConnectorRole.WORKER, kv_cache_config
|
||||
)
|
||||
worker = connector.connector_worker
|
||||
|
||||
block_len = 4096
|
||||
|
||||
@@ -37,9 +37,15 @@ from vllm.utils.network_utils import (
|
||||
get_ip,
|
||||
make_zmq_path,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
|
||||
from .utils import create_request, create_scheduler
|
||||
|
||||
|
||||
def _make_test_kv_cache_config() -> KVCacheConfig:
|
||||
return KVCacheConfig(num_blocks=0, kv_cache_tensors=[], kv_cache_groups=[])
|
||||
|
||||
|
||||
aiter_available = importlib.util.find_spec("aiter") is not None
|
||||
mori_available = importlib.util.find_spec("mori") is not None
|
||||
|
||||
@@ -462,7 +468,11 @@ def test_register_kv_caches(mock_parallel_groups):
|
||||
)
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
connector = MoRIIOConnector(vllm_config, KVConnectorRole.WORKER)
|
||||
connector = MoRIIOConnector(
|
||||
vllm_config,
|
||||
KVConnectorRole.WORKER,
|
||||
_make_test_kv_cache_config(),
|
||||
)
|
||||
connector.connector_worker = FakeMoRIIOConnectorWorker(
|
||||
vllm_config, connector.engine_id, hand_shake_latency=0
|
||||
)
|
||||
@@ -554,7 +564,11 @@ def test_moriio_handshake_returns_metadata(mock_parallel_groups):
|
||||
}
|
||||
)
|
||||
with set_current_vllm_config(vllm_config):
|
||||
connector = MoRIIOConnector(vllm_config, KVConnectorRole.WORKER)
|
||||
connector = MoRIIOConnector(
|
||||
vllm_config,
|
||||
KVConnectorRole.WORKER,
|
||||
_make_test_kv_cache_config(),
|
||||
)
|
||||
|
||||
# Execute register_kv_caches
|
||||
connector.register_kv_caches(kv_caches)
|
||||
|
||||
@@ -17,6 +17,7 @@ from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.core.kv_cache_utils import BlockHash
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.core.sched.scheduler import Scheduler
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.request import Request
|
||||
|
||||
|
||||
@@ -26,7 +27,7 @@ class DummyConnectorMetadata(KVConnectorMetadata):
|
||||
|
||||
|
||||
class DummyKVConnector(KVConnectorBase_V1):
|
||||
def __init__(self, vllm_config, role, kv_cache_config=None):
|
||||
def __init__(self, vllm_config, role, kv_cache_config: KVCacheConfig):
|
||||
super().__init__(vllm_config, role, kv_cache_config)
|
||||
|
||||
def get_num_new_matched_tokens(
|
||||
|
||||
@@ -293,9 +293,14 @@ def create_model_runner_output(
|
||||
|
||||
|
||||
class TestExampleConnector(ExampleConnector):
|
||||
def __init__(self, config: VllmConfig, role, kv_cache_config):
|
||||
def __init__(
|
||||
self,
|
||||
config: VllmConfig,
|
||||
role: KVConnectorRole,
|
||||
kv_cache_config: KVCacheConfig,
|
||||
):
|
||||
self.name = config.kv_transfer_config.kv_connector_extra_config["name"]
|
||||
self._connector = ExampleConnector(config, role)
|
||||
self._connector = ExampleConnector(config, role, kv_cache_config)
|
||||
self.call_record: dict[str, int] = defaultdict(int)
|
||||
# Use a unique temp file per connector
|
||||
self._event_file = (
|
||||
@@ -368,7 +373,7 @@ class MockKVConnector(KVConnectorBase_V1):
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
role: KVConnectorRole,
|
||||
kv_cache_config: KVCacheConfig | None = None,
|
||||
kv_cache_config: KVCacheConfig,
|
||||
):
|
||||
super().__init__(vllm_config, role, kv_cache_config)
|
||||
extra_config = self._kv_transfer_config.kv_connector_extra_config
|
||||
|
||||
@@ -44,14 +44,12 @@ class KVConnectorFactory:
|
||||
cls,
|
||||
config: "VllmConfig",
|
||||
role: KVConnectorRole,
|
||||
kv_cache_config: "KVCacheConfig | None" = None,
|
||||
kv_cache_config: "KVCacheConfig",
|
||||
) -> KVConnectorBase:
|
||||
kv_transfer_config = config.kv_transfer_config
|
||||
if kv_transfer_config is None:
|
||||
raise ValueError("kv_transfer_config must be set to create a connector")
|
||||
connector_cls, compat_sig = cls._get_connector_class_with_compat(
|
||||
kv_transfer_config
|
||||
)
|
||||
connector_cls = cls.get_connector_class(kv_transfer_config)
|
||||
|
||||
# check if the connector supports HMA
|
||||
hma_enabled = not config.scheduler_config.disable_hybrid_kv_cache_manager
|
||||
@@ -74,12 +72,7 @@ class KVConnectorFactory:
|
||||
# - Co-locate with worker process
|
||||
# - Should only be used inside the forward context & attention layer
|
||||
# We build separately to enforce strict separation
|
||||
if compat_sig:
|
||||
# Old signature: __init__(self, vllm_config, role)
|
||||
return connector_cls(config, role)
|
||||
else:
|
||||
# New signature: __init__(self, vllm_config, role, kv_cache_config)
|
||||
return connector_cls(config, role, kv_cache_config)
|
||||
return connector_cls(config, role, kv_cache_config)
|
||||
|
||||
@classmethod
|
||||
def get_connector_class_by_name(
|
||||
@@ -100,13 +93,12 @@ class KVConnectorFactory:
|
||||
return cls._registry[connector_name]()
|
||||
|
||||
@classmethod
|
||||
def _get_connector_class_with_compat(
|
||||
def get_connector_class(
|
||||
cls, kv_transfer_config: "KVTransferConfig"
|
||||
) -> tuple[type[KVConnectorBaseType], bool]:
|
||||
) -> type[KVConnectorBaseType]:
|
||||
connector_name = kv_transfer_config.kv_connector
|
||||
if connector_name is None:
|
||||
raise ValueError("Connector name is not set in KVTransferConfig")
|
||||
compat_sig = False
|
||||
connector_module_path = kv_transfer_config.kv_connector_module_path
|
||||
if connector_module_path is not None and not connector_module_path:
|
||||
raise ValueError("kv_connector_module_path cannot be an empty string.")
|
||||
@@ -121,24 +113,18 @@ class KVConnectorFactory:
|
||||
) from e
|
||||
connector_cls = cast(type[KVConnectorBaseType], connector_cls)
|
||||
if not supports_kw(connector_cls, "kv_cache_config"):
|
||||
compat_sig = True
|
||||
logger.warning(
|
||||
"Connector %s uses deprecated signature with 2 required arguments. "
|
||||
"Please update to include kv_cache_config as the second argument.",
|
||||
connector_cls.__name__,
|
||||
msg = (
|
||||
f"Connector {connector_cls.__name__} uses deprecated "
|
||||
"2-argument constructor signature. External v1 KV "
|
||||
"connectors must accept kv_cache_config as the third "
|
||||
"constructor argument and pass it to super().__init__()."
|
||||
)
|
||||
logger.error(msg)
|
||||
raise ValueError(msg)
|
||||
elif connector_name in cls._registry:
|
||||
connector_cls = cls._registry[connector_name]()
|
||||
else:
|
||||
raise ValueError(f"Unsupported connector type: {connector_name}")
|
||||
return connector_cls, compat_sig
|
||||
|
||||
@classmethod
|
||||
def get_connector_class(
|
||||
cls, kv_transfer_config: "KVTransferConfig"
|
||||
) -> type[KVConnectorBaseType]:
|
||||
"""Get the connector class by name."""
|
||||
connector_cls, _ = cls._get_connector_class_with_compat(kv_transfer_config)
|
||||
return connector_cls
|
||||
|
||||
|
||||
|
||||
@@ -184,7 +184,7 @@ class KVConnectorBase_V1(ABC):
|
||||
self,
|
||||
vllm_config: "VllmConfig",
|
||||
role: KVConnectorRole,
|
||||
kv_cache_config: "KVCacheConfig | None" = None,
|
||||
kv_cache_config: "KVCacheConfig",
|
||||
):
|
||||
logger.warning(
|
||||
"Initializing KVConnectorBase_V1. This API is experimental and "
|
||||
@@ -197,13 +197,6 @@ class KVConnectorBase_V1(ABC):
|
||||
else:
|
||||
raise ValueError("kv_transfer_config must be set for KVConnectorBase_V1")
|
||||
self._kv_cache_config = kv_cache_config
|
||||
if self._kv_cache_config is None:
|
||||
logger.warning(
|
||||
"KVConnectorBase_V1 initialized without kv_cache_config. "
|
||||
"This is deprecated - please update your connector to accept "
|
||||
"kv_cache_config as the third constructor argument and pass it "
|
||||
"to super().__init__()."
|
||||
)
|
||||
self._role = role
|
||||
|
||||
@property
|
||||
|
||||
@@ -87,7 +87,7 @@ class DecodeBenchConnector(KVConnectorBase_V1, SupportsHMA):
|
||||
self,
|
||||
vllm_config: "VllmConfig",
|
||||
role: KVConnectorRole,
|
||||
kv_cache_config: "KVCacheConfig | None" = None,
|
||||
kv_cache_config: "KVCacheConfig",
|
||||
):
|
||||
super().__init__(vllm_config, role, kv_cache_config)
|
||||
|
||||
|
||||
@@ -92,7 +92,7 @@ class ExampleConnector(KVConnectorBase_V1):
|
||||
self,
|
||||
vllm_config: "VllmConfig",
|
||||
role: KVConnectorRole,
|
||||
kv_cache_config: "KVCacheConfig | None" = None,
|
||||
kv_cache_config: "KVCacheConfig",
|
||||
):
|
||||
super().__init__(
|
||||
vllm_config=vllm_config,
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import safetensors
|
||||
import torch
|
||||
@@ -120,7 +120,7 @@ class ExampleHiddenStatesConnector(KVConnectorBase_V1):
|
||||
self,
|
||||
vllm_config: "VllmConfig",
|
||||
role: KVConnectorRole,
|
||||
kv_cache_config: Optional["KVCacheConfig"] = None,
|
||||
kv_cache_config: "KVCacheConfig",
|
||||
):
|
||||
super().__init__(
|
||||
vllm_config=vllm_config,
|
||||
|
||||
@@ -476,7 +476,7 @@ class LMCacheMPConnector(KVConnectorBase_V1):
|
||||
self,
|
||||
vllm_config: "VllmConfig",
|
||||
role: KVConnectorRole,
|
||||
kv_cache_config: "KVCacheConfig | None" = None,
|
||||
kv_cache_config: "KVCacheConfig",
|
||||
):
|
||||
super().__init__(vllm_config, role, kv_cache_config)
|
||||
|
||||
|
||||
@@ -342,7 +342,7 @@ class MooncakeConnector(KVConnectorBase_V1, SupportsHMA):
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
role: KVConnectorRole,
|
||||
kv_cache_config: "KVCacheConfig | None" = None,
|
||||
kv_cache_config: "KVCacheConfig",
|
||||
):
|
||||
super().__init__(vllm_config, role, kv_cache_config)
|
||||
|
||||
|
||||
@@ -93,9 +93,9 @@ class MoRIIOConnector(KVConnectorBase_V1):
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
role: KVConnectorRole,
|
||||
kv_cache_config: "KVCacheConfig | None" = None,
|
||||
kv_cache_config: "KVCacheConfig",
|
||||
):
|
||||
super().__init__(vllm_config, role)
|
||||
super().__init__(vllm_config, role, kv_cache_config)
|
||||
assert vllm_config.kv_transfer_config is not None, (
|
||||
"kv_transfer_config must be set for MoRIIOConnector"
|
||||
)
|
||||
|
||||
@@ -52,11 +52,10 @@ class OffloadingConnector(KVConnectorBase_V1, SupportsHMA):
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
role: KVConnectorRole,
|
||||
kv_cache_config: KVCacheConfig | None = None,
|
||||
kv_cache_config: KVCacheConfig,
|
||||
):
|
||||
super().__init__(vllm_config, role, kv_cache_config)
|
||||
|
||||
assert kv_cache_config is not None
|
||||
spec = OffloadingSpecFactory.create_spec(vllm_config, kv_cache_config)
|
||||
|
||||
self.connector_scheduler: OffloadingConnectorScheduler | None = None
|
||||
|
||||
@@ -76,7 +76,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
|
||||
self,
|
||||
vllm_config: "VllmConfig",
|
||||
role: KVConnectorRole,
|
||||
kv_cache_config: "KVCacheConfig | None" = None,
|
||||
kv_cache_config: "KVCacheConfig",
|
||||
):
|
||||
super().__init__(
|
||||
vllm_config=vllm_config,
|
||||
|
||||
@@ -49,7 +49,7 @@ class SimpleCPUOffloadConnector(KVConnectorBase_V1, SupportsHMA):
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
role: KVConnectorRole,
|
||||
kv_cache_config: "KVCacheConfig | None" = None,
|
||||
kv_cache_config: "KVCacheConfig",
|
||||
):
|
||||
super().__init__(vllm_config, role, kv_cache_config)
|
||||
|
||||
|
||||
@@ -49,7 +49,7 @@ def is_v1_kv_transfer_group(connector: KVConnectorBaseType | None = None) -> boo
|
||||
|
||||
|
||||
def ensure_kv_transfer_initialized(
|
||||
vllm_config: "VllmConfig", kv_cache_config: "KVCacheConfig | None" = None
|
||||
vllm_config: "VllmConfig", kv_cache_config: "KVCacheConfig"
|
||||
) -> None:
|
||||
"""
|
||||
Initialize KV cache transfer parallel group.
|
||||
|
||||
Reference in New Issue
Block a user