[KV Connector] Remove compat support for pre-v0.12.0 constructor signatures without KVCacheConfig (#39832)

The v0.12.0 release contained initial support for HMA in KV Connectors. As part
of these changes, a KVCacheConfig argument was added to KV connector
constructors. Backwards compatibility support for out-of-tree connectors was
included in this change, with a very prominent warning. See #25712 and #27887.

Since the warning has been around for over 5 months, we can safely remove
the support of it.

Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Wentao Ye
2026-05-09 19:39:46 -04:00
committed by GitHub
parent f80aa53c9d
commit ea0e501bb1
22 changed files with 119 additions and 341 deletions
@@ -20,6 +20,7 @@ from vllm.v1.request import Request
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig
logger = logging.getLogger()
logging.basicConfig(level=logging.INFO)
@@ -35,8 +36,17 @@ class LoadRecoveryExampleConnectorMetadata(ExampleConnectorMetadata):
class LoadRecoveryExampleConnector(ExampleConnector):
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
super().__init__(vllm_config=vllm_config, role=role)
def __init__(
self,
vllm_config: "VllmConfig",
role: KVConnectorRole,
kv_cache_config: "KVCacheConfig",
):
super().__init__(
vllm_config=vllm_config,
role=role,
kv_cache_config=kv_cache_config,
)
self._async_load = vllm_config.kv_transfer_config.get_from_extra_config(
"async_load", False
)
+1 -1
View File
@@ -66,7 +66,7 @@ class DummyKVConnector(KVConnectorBase_V1):
self,
vllm_config: VllmConfig,
role: KVConnectorRole,
kv_cache_config: KVCacheConfig | None = None,
kv_cache_config: KVCacheConfig,
):
super().__init__(vllm_config, role, kv_cache_config)
# Get the status file path from extra config
@@ -1,275 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Unit tests for backwards compatibility with external KV connector implementations.
This test ensures that external connectors (loaded via kv_connector_module_path)
implemented with the old signature continue to work:
- Old signature: __init__(self, vllm_config, role)
- New signature: __init__(self, vllm_config, role, kv_cache_config)
"""
from typing import TYPE_CHECKING
from unittest.mock import patch
import pytest
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.distributed.kv_transfer.kv_connector.v1 import (
KVConnectorBase_V1,
KVConnectorRole,
)
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.core.sched.output import SchedulerOutput
from .utils import create_scheduler, create_vllm_config
if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.request import Request
class OldStyleTestConnector(KVConnectorBase_V1):
"""
Test connector using the old signature with 2 required arguments.
This simulates external connectors that haven't been updated yet.
"""
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
# Old-style call to super().__init__ with only 2 arguments
super().__init__(vllm_config=vllm_config, role=role)
def get_num_new_matched_tokens(
self, request: "Request", num_computed_tokens: int
) -> tuple[int | None, bool]:
return 0, False
def update_state_after_alloc(
self,
request: "Request",
blocks: "KVCacheBlocks",
num_external_tokens: int,
):
pass
def build_connector_meta(self, scheduler_output: SchedulerOutput):
return None
def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
pass
def wait_for_layer_load(self, layer_name: str) -> None:
pass
def save_kv_layer(
self,
layer_name: str,
kv_layer,
attn_metadata: AttentionMetadata,
**kwargs,
) -> None:
pass
def wait_for_save(self):
pass
class NewStyleTestConnector(KVConnectorBase_V1):
"""
Test connector using the new signature with 3 required arguments.
"""
def __init__(
self,
vllm_config: "VllmConfig",
role: KVConnectorRole,
kv_cache_config: "KVCacheConfig",
):
# New-style call to super().__init__ with all 3 arguments
super().__init__(
vllm_config=vllm_config, role=role, kv_cache_config=kv_cache_config
)
def get_num_new_matched_tokens(
self, request: "Request", num_computed_tokens: int
) -> tuple[int | None, bool]:
return 0, False
def update_state_after_alloc(
self,
request: "Request",
blocks: "KVCacheBlocks",
num_external_tokens: int,
):
pass
def build_connector_meta(self, scheduler_output: SchedulerOutput):
return None
def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
pass
def wait_for_layer_load(self, layer_name: str) -> None:
pass
def save_kv_layer(
self,
layer_name: str,
kv_layer,
attn_metadata: AttentionMetadata,
**kwargs,
) -> None:
pass
def wait_for_save(self):
pass
@pytest.mark.parametrize("role", [KVConnectorRole.SCHEDULER, KVConnectorRole.WORKER])
def test_external_old_signature_factory_instantiation(role):
"""
Test that external connectors with old signature (2 required args) loaded
via kv_connector_module_path are correctly instantiated with backwards
compatibility support.
"""
vllm_config = create_vllm_config()
vllm_config.kv_transfer_config.kv_connector = "OldStyleTestConnector"
vllm_config.kv_transfer_config.kv_connector_module_path = (
"tests.v1.kv_connector.unit.test_backwards_compatibility"
)
scheduler = create_scheduler(vllm_config)
kv_cache_config = scheduler.kv_cache_config
connector = KVConnectorFactory.create_connector(vllm_config, role, kv_cache_config)
assert connector is not None
assert isinstance(connector, OldStyleTestConnector)
assert connector.role == role
assert connector._kv_cache_config is None
@pytest.mark.parametrize("role", [KVConnectorRole.SCHEDULER, KVConnectorRole.WORKER])
def test_external_new_signature_factory_instantiation(role):
"""
Test that external connectors with new signature (3 required args) loaded
via kv_connector_module_path are correctly instantiated.
"""
vllm_config = create_vllm_config()
vllm_config.kv_transfer_config.kv_connector = "NewStyleTestConnector"
vllm_config.kv_transfer_config.kv_connector_module_path = (
"tests.v1.kv_connector.unit.test_backwards_compatibility"
)
scheduler = create_scheduler(vllm_config)
kv_cache_config = scheduler.kv_cache_config
connector = KVConnectorFactory.create_connector(vllm_config, role, kv_cache_config)
assert connector is not None
assert isinstance(connector, NewStyleTestConnector)
assert connector.role == role
assert connector._kv_cache_config is not None
assert connector._kv_cache_config == kv_cache_config
@pytest.mark.parametrize("role", [KVConnectorRole.SCHEDULER, KVConnectorRole.WORKER])
def test_old_signature_super_init(role):
"""
Test that old-style connectors can call super().__init__() without
kv_cache_config parameter.
"""
vllm_config = create_vllm_config()
connector = OldStyleTestConnector(vllm_config, role)
assert connector is not None
assert connector.role == role
assert connector._kv_cache_config is None
def test_old_signature_super_init_with_kwargs():
"""
Test that old-style connectors can call super().__init__() with keyword
arguments in different orders.
"""
vllm_config = create_vllm_config()
# Test with vllm_config= and role= kwargs
connector1 = OldStyleTestConnector(
vllm_config=vllm_config, role=KVConnectorRole.SCHEDULER
)
assert connector1 is not None
assert connector1._kv_cache_config is None
# Test with role= and vllm_config= in reversed order
connector2 = OldStyleTestConnector(
role=KVConnectorRole.WORKER, vllm_config=vllm_config
)
assert connector2 is not None
assert connector2._kv_cache_config is None
def test_internal_connector_uses_new_signature():
"""
Test that internal connectors (registered in factory) always use the new
signature and get kv_cache_config.
"""
from vllm.distributed.kv_transfer.kv_connector.v1.example_connector import (
ExampleConnector,
)
vllm_config = create_vllm_config()
vllm_config.kv_transfer_config.kv_connector = "ExampleConnector"
scheduler = create_scheduler(vllm_config)
kv_cache_config = scheduler.kv_cache_config
connector = KVConnectorFactory.create_connector(
vllm_config, KVConnectorRole.SCHEDULER, kv_cache_config
)
assert connector is not None
assert isinstance(connector, ExampleConnector)
assert connector._kv_cache_config is not None
assert connector._kv_cache_config == kv_cache_config
def test_signature_detection_with_mocking():
"""
Test that the factory correctly applies compat_sig flag returned from
_get_connector_class_with_compat.
"""
vllm_config = create_vllm_config()
scheduler = create_scheduler(vllm_config)
kv_cache_config = scheduler.kv_cache_config
# Mock _get_connector_class_with_compat to return old-style connector
with patch.object(
KVConnectorFactory,
"_get_connector_class_with_compat",
return_value=(OldStyleTestConnector, True),
):
old_connector = KVConnectorFactory.create_connector(
vllm_config, KVConnectorRole.SCHEDULER, kv_cache_config
)
assert old_connector is not None
assert isinstance(old_connector, OldStyleTestConnector)
assert old_connector._kv_cache_config is None
# Mock _get_connector_class_with_compat to return new-style connector
with patch.object(
KVConnectorFactory,
"_get_connector_class_with_compat",
return_value=(NewStyleTestConnector, False),
):
new_connector = KVConnectorFactory.create_connector(
vllm_config, KVConnectorRole.SCHEDULER, kv_cache_config
)
assert new_connector is not None
assert isinstance(new_connector, NewStyleTestConnector)
assert new_connector._kv_cache_config is not None
assert new_connector._kv_cache_config == kv_cache_config
@@ -58,7 +58,9 @@ class DecodeBenchTestRunner:
# Create worker-side connector
self.worker_connector = DecodeBenchConnector(
vllm_config, KVConnectorRole.WORKER
vllm_config,
KVConnectorRole.WORKER,
self.scheduler.kv_cache_config,
)
# Create dummy KV caches for testing
@@ -10,6 +10,7 @@ from vllm.distributed.kv_transfer.kv_transfer_state import (
get_kv_transfer_group,
)
from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
# Importing utils registers TestExampleConnector with the factory
@@ -38,7 +39,10 @@ def test_kv_connector_mixin_clears_metadata():
vllm_config.kv_transfer_config.kv_connector_extra_config["name"] = "unit"
# Initialize the global connector instance
ensure_kv_transfer_initialized(vllm_config)
kv_cache_config = KVCacheConfig(
num_blocks=0, kv_cache_tensors=[], kv_cache_groups=[]
)
ensure_kv_transfer_initialized(vllm_config, kv_cache_config)
try:
# Minimal scheduler output with empty metadata; mixin should still
@@ -26,11 +26,16 @@ from vllm.distributed.kv_transfer.kv_connector.v1.mooncake.mooncake_utils import
)
from vllm.utils.network_utils import get_open_port
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.request import RequestStatus
from .utils import create_request, create_scheduler, create_vllm_config
def _make_test_kv_cache_config() -> KVCacheConfig:
return KVCacheConfig(num_blocks=0, kv_cache_tensors=[], kv_cache_groups=[])
class FakeMooncakeWrapper:
"""Mock Mooncake TransferEngine for unit testing environments."""
@@ -321,7 +326,11 @@ async def test_kv_producer(monkeypatch):
)
with set_current_vllm_config(vllm_config), patch_worker_dependencies():
prefill_connector = MooncakeConnector(vllm_config, KVConnectorRole.WORKER)
prefill_connector = MooncakeConnector(
vllm_config,
KVConnectorRole.WORKER,
_make_test_kv_cache_config(),
)
prefill_worker = prefill_connector.connector_worker
prefill_worker.kv_caches_base_addr = [0x1000]
block_len = 4096
@@ -473,7 +482,11 @@ async def test_kv_consumuer(monkeypatch):
)
with set_current_vllm_config(vllm_config), patch_worker_dependencies() as mocks:
decode_connector = MooncakeConnector(vllm_config, KVConnectorRole.WORKER)
decode_connector = MooncakeConnector(
vllm_config,
KVConnectorRole.WORKER,
_make_test_kv_cache_config(),
)
decode_worker = decode_connector.connector_worker
decode_worker.kv_caches_base_addr = [0x1000]
decode_worker.rpc_port = 54321
@@ -533,7 +546,11 @@ async def test_worker_get_finished_timeout(monkeypatch):
kv_connector="MooncakeConnector", kv_role="kv_producer"
)
with set_current_vllm_config(vllm_config), patch_worker_dependencies():
prefill_connector = MooncakeConnector(vllm_config, KVConnectorRole.WORKER)
prefill_connector = MooncakeConnector(
vllm_config,
KVConnectorRole.WORKER,
_make_test_kv_cache_config(),
)
prefill_worker = prefill_connector.connector_worker
# Add an expired request (expire_time is in the past).
@@ -579,7 +596,11 @@ def test_register_kv_caches():
"vllm.distributed.kv_transfer.kv_connector.v1.mooncake.mooncake_connector.threading.Thread"
) as mock_thread,
):
connector = MooncakeConnector(vllm_config, KVConnectorRole.WORKER)
connector = MooncakeConnector(
vllm_config,
KVConnectorRole.WORKER,
_make_test_kv_cache_config(),
)
worker = connector.connector_worker
mock_thread.return_value.is_alive.return_value = False
@@ -628,7 +649,11 @@ def test_register_kv_caches_supports_mixed_mla_and_eagle_shapes():
"vllm.distributed.kv_transfer.kv_connector.v1.mooncake.mooncake_connector.threading.Thread"
) as mock_thread,
):
connector = MooncakeConnector(vllm_config, KVConnectorRole.WORKER)
connector = MooncakeConnector(
vllm_config,
KVConnectorRole.WORKER,
_make_test_kv_cache_config(),
)
worker = connector.connector_worker
mock_thread.return_value.is_alive.return_value = False
@@ -688,7 +713,11 @@ async def test_kv_producer_heterogeneous_tp(monkeypatch, d_tp_size):
)
with set_current_vllm_config(vllm_config), patch_worker_dependencies():
prefill_connector = MooncakeConnector(vllm_config, KVConnectorRole.WORKER)
prefill_connector = MooncakeConnector(
vllm_config,
KVConnectorRole.WORKER,
_make_test_kv_cache_config(),
)
prefill_worker = prefill_connector.connector_worker
# Override TP rank/size to simulate P TP=2
@@ -221,9 +221,14 @@ async def test_build_transfer_params_multi_group_trimming(monkeypatch):
vllm_config = create_vllm_config(
kv_connector="MooncakeConnector", kv_role="kv_producer"
)
kv_cache_config = make_kv_cache_config(
block_size=vllm_config.cache_config.block_size, swa_enabled=True
)
with set_current_vllm_config(vllm_config), patch_worker_dependencies():
connector = MooncakeConnector(vllm_config, KVConnectorRole.WORKER)
connector = MooncakeConnector(
vllm_config, KVConnectorRole.WORKER, kv_cache_config
)
worker = connector.connector_worker
block_len = 4096
@@ -304,9 +309,14 @@ async def test_build_transfer_params_group_count_mismatch(monkeypatch):
vllm_config = create_vllm_config(
kv_connector="MooncakeConnector", kv_role="kv_producer"
)
kv_cache_config = make_kv_cache_config(
block_size=vllm_config.cache_config.block_size, swa_enabled=True
)
with set_current_vllm_config(vllm_config), patch_worker_dependencies():
connector = MooncakeConnector(vllm_config, KVConnectorRole.WORKER)
connector = MooncakeConnector(
vllm_config, KVConnectorRole.WORKER, kv_cache_config
)
worker = connector.connector_worker
block_len = 4096
@@ -37,9 +37,15 @@ from vllm.utils.network_utils import (
get_ip,
make_zmq_path,
)
from vllm.v1.kv_cache_interface import KVCacheConfig
from .utils import create_request, create_scheduler
def _make_test_kv_cache_config() -> KVCacheConfig:
return KVCacheConfig(num_blocks=0, kv_cache_tensors=[], kv_cache_groups=[])
aiter_available = importlib.util.find_spec("aiter") is not None
mori_available = importlib.util.find_spec("mori") is not None
@@ -462,7 +468,11 @@ def test_register_kv_caches(mock_parallel_groups):
)
with set_current_vllm_config(vllm_config):
connector = MoRIIOConnector(vllm_config, KVConnectorRole.WORKER)
connector = MoRIIOConnector(
vllm_config,
KVConnectorRole.WORKER,
_make_test_kv_cache_config(),
)
connector.connector_worker = FakeMoRIIOConnectorWorker(
vllm_config, connector.engine_id, hand_shake_latency=0
)
@@ -554,7 +564,11 @@ def test_moriio_handshake_returns_metadata(mock_parallel_groups):
}
)
with set_current_vllm_config(vllm_config):
connector = MoRIIOConnector(vllm_config, KVConnectorRole.WORKER)
connector = MoRIIOConnector(
vllm_config,
KVConnectorRole.WORKER,
_make_test_kv_cache_config(),
)
# Execute register_kv_caches
connector.register_kv_caches(kv_caches)
@@ -17,6 +17,7 @@ from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.kv_cache_utils import BlockHash
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.request import Request
@@ -26,7 +27,7 @@ class DummyConnectorMetadata(KVConnectorMetadata):
class DummyKVConnector(KVConnectorBase_V1):
def __init__(self, vllm_config, role, kv_cache_config=None):
def __init__(self, vllm_config, role, kv_cache_config: KVCacheConfig):
super().__init__(vllm_config, role, kv_cache_config)
def get_num_new_matched_tokens(
+8 -3
View File
@@ -293,9 +293,14 @@ def create_model_runner_output(
class TestExampleConnector(ExampleConnector):
def __init__(self, config: VllmConfig, role, kv_cache_config):
def __init__(
self,
config: VllmConfig,
role: KVConnectorRole,
kv_cache_config: KVCacheConfig,
):
self.name = config.kv_transfer_config.kv_connector_extra_config["name"]
self._connector = ExampleConnector(config, role)
self._connector = ExampleConnector(config, role, kv_cache_config)
self.call_record: dict[str, int] = defaultdict(int)
# Use a unique temp file per connector
self._event_file = (
@@ -368,7 +373,7 @@ class MockKVConnector(KVConnectorBase_V1):
self,
vllm_config: VllmConfig,
role: KVConnectorRole,
kv_cache_config: KVCacheConfig | None = None,
kv_cache_config: KVCacheConfig,
):
super().__init__(vllm_config, role, kv_cache_config)
extra_config = self._kv_transfer_config.kv_connector_extra_config
@@ -44,14 +44,12 @@ class KVConnectorFactory:
cls,
config: "VllmConfig",
role: KVConnectorRole,
kv_cache_config: "KVCacheConfig | None" = None,
kv_cache_config: "KVCacheConfig",
) -> KVConnectorBase:
kv_transfer_config = config.kv_transfer_config
if kv_transfer_config is None:
raise ValueError("kv_transfer_config must be set to create a connector")
connector_cls, compat_sig = cls._get_connector_class_with_compat(
kv_transfer_config
)
connector_cls = cls.get_connector_class(kv_transfer_config)
# check if the connector supports HMA
hma_enabled = not config.scheduler_config.disable_hybrid_kv_cache_manager
@@ -74,12 +72,7 @@ class KVConnectorFactory:
# - Co-locate with worker process
# - Should only be used inside the forward context & attention layer
# We build separately to enforce strict separation
if compat_sig:
# Old signature: __init__(self, vllm_config, role)
return connector_cls(config, role)
else:
# New signature: __init__(self, vllm_config, role, kv_cache_config)
return connector_cls(config, role, kv_cache_config)
return connector_cls(config, role, kv_cache_config)
@classmethod
def get_connector_class_by_name(
@@ -100,13 +93,12 @@ class KVConnectorFactory:
return cls._registry[connector_name]()
@classmethod
def _get_connector_class_with_compat(
def get_connector_class(
cls, kv_transfer_config: "KVTransferConfig"
) -> tuple[type[KVConnectorBaseType], bool]:
) -> type[KVConnectorBaseType]:
connector_name = kv_transfer_config.kv_connector
if connector_name is None:
raise ValueError("Connector name is not set in KVTransferConfig")
compat_sig = False
connector_module_path = kv_transfer_config.kv_connector_module_path
if connector_module_path is not None and not connector_module_path:
raise ValueError("kv_connector_module_path cannot be an empty string.")
@@ -121,24 +113,18 @@ class KVConnectorFactory:
) from e
connector_cls = cast(type[KVConnectorBaseType], connector_cls)
if not supports_kw(connector_cls, "kv_cache_config"):
compat_sig = True
logger.warning(
"Connector %s uses deprecated signature with 2 required arguments. "
"Please update to include kv_cache_config as the second argument.",
connector_cls.__name__,
msg = (
f"Connector {connector_cls.__name__} uses deprecated "
"2-argument constructor signature. External v1 KV "
"connectors must accept kv_cache_config as the third "
"constructor argument and pass it to super().__init__()."
)
logger.error(msg)
raise ValueError(msg)
elif connector_name in cls._registry:
connector_cls = cls._registry[connector_name]()
else:
raise ValueError(f"Unsupported connector type: {connector_name}")
return connector_cls, compat_sig
@classmethod
def get_connector_class(
cls, kv_transfer_config: "KVTransferConfig"
) -> type[KVConnectorBaseType]:
"""Get the connector class by name."""
connector_cls, _ = cls._get_connector_class_with_compat(kv_transfer_config)
return connector_cls
@@ -184,7 +184,7 @@ class KVConnectorBase_V1(ABC):
self,
vllm_config: "VllmConfig",
role: KVConnectorRole,
kv_cache_config: "KVCacheConfig | None" = None,
kv_cache_config: "KVCacheConfig",
):
logger.warning(
"Initializing KVConnectorBase_V1. This API is experimental and "
@@ -197,13 +197,6 @@ class KVConnectorBase_V1(ABC):
else:
raise ValueError("kv_transfer_config must be set for KVConnectorBase_V1")
self._kv_cache_config = kv_cache_config
if self._kv_cache_config is None:
logger.warning(
"KVConnectorBase_V1 initialized without kv_cache_config. "
"This is deprecated - please update your connector to accept "
"kv_cache_config as the third constructor argument and pass it "
"to super().__init__()."
)
self._role = role
@property
@@ -87,7 +87,7 @@ class DecodeBenchConnector(KVConnectorBase_V1, SupportsHMA):
self,
vllm_config: "VllmConfig",
role: KVConnectorRole,
kv_cache_config: "KVCacheConfig | None" = None,
kv_cache_config: "KVCacheConfig",
):
super().__init__(vllm_config, role, kv_cache_config)
@@ -92,7 +92,7 @@ class ExampleConnector(KVConnectorBase_V1):
self,
vllm_config: "VllmConfig",
role: KVConnectorRole,
kv_cache_config: "KVCacheConfig | None" = None,
kv_cache_config: "KVCacheConfig",
):
super().__init__(
vllm_config=vllm_config,
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any
import safetensors
import torch
@@ -120,7 +120,7 @@ class ExampleHiddenStatesConnector(KVConnectorBase_V1):
self,
vllm_config: "VllmConfig",
role: KVConnectorRole,
kv_cache_config: Optional["KVCacheConfig"] = None,
kv_cache_config: "KVCacheConfig",
):
super().__init__(
vllm_config=vllm_config,
@@ -476,7 +476,7 @@ class LMCacheMPConnector(KVConnectorBase_V1):
self,
vllm_config: "VllmConfig",
role: KVConnectorRole,
kv_cache_config: "KVCacheConfig | None" = None,
kv_cache_config: "KVCacheConfig",
):
super().__init__(vllm_config, role, kv_cache_config)
@@ -342,7 +342,7 @@ class MooncakeConnector(KVConnectorBase_V1, SupportsHMA):
self,
vllm_config: VllmConfig,
role: KVConnectorRole,
kv_cache_config: "KVCacheConfig | None" = None,
kv_cache_config: "KVCacheConfig",
):
super().__init__(vllm_config, role, kv_cache_config)
@@ -93,9 +93,9 @@ class MoRIIOConnector(KVConnectorBase_V1):
self,
vllm_config: VllmConfig,
role: KVConnectorRole,
kv_cache_config: "KVCacheConfig | None" = None,
kv_cache_config: "KVCacheConfig",
):
super().__init__(vllm_config, role)
super().__init__(vllm_config, role, kv_cache_config)
assert vllm_config.kv_transfer_config is not None, (
"kv_transfer_config must be set for MoRIIOConnector"
)
@@ -52,11 +52,10 @@ class OffloadingConnector(KVConnectorBase_V1, SupportsHMA):
self,
vllm_config: VllmConfig,
role: KVConnectorRole,
kv_cache_config: KVCacheConfig | None = None,
kv_cache_config: KVCacheConfig,
):
super().__init__(vllm_config, role, kv_cache_config)
assert kv_cache_config is not None
spec = OffloadingSpecFactory.create_spec(vllm_config, kv_cache_config)
self.connector_scheduler: OffloadingConnectorScheduler | None = None
@@ -76,7 +76,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
self,
vllm_config: "VllmConfig",
role: KVConnectorRole,
kv_cache_config: "KVCacheConfig | None" = None,
kv_cache_config: "KVCacheConfig",
):
super().__init__(
vllm_config=vllm_config,
@@ -49,7 +49,7 @@ class SimpleCPUOffloadConnector(KVConnectorBase_V1, SupportsHMA):
self,
vllm_config: VllmConfig,
role: KVConnectorRole,
kv_cache_config: "KVCacheConfig | None" = None,
kv_cache_config: "KVCacheConfig",
):
super().__init__(vllm_config, role, kv_cache_config)
@@ -49,7 +49,7 @@ def is_v1_kv_transfer_group(connector: KVConnectorBaseType | None = None) -> boo
def ensure_kv_transfer_initialized(
vllm_config: "VllmConfig", kv_cache_config: "KVCacheConfig | None" = None
vllm_config: "VllmConfig", kv_cache_config: "KVCacheConfig"
) -> None:
"""
Initialize KV cache transfer parallel group.