diff --git a/docs/design/nixl_kv_cache_lease.md b/docs/design/nixl_kv_cache_lease.md index a3fdaafe345..aa7683bb9e1 100644 --- a/docs/design/nixl_kv_cache_lease.md +++ b/docs/design/nixl_kv_cache_lease.md @@ -128,7 +128,7 @@ The lease mechanism is controlled through `kv_connector_extra_config` in `--kv-t vllm serve \ --kv-transfer-config '{ "kv_connector": "NixlConnector", - "kv_role": "kv_both", + "kv_role": "kv_producer", "kv_connector_extra_config": {"kv_lease_duration": 60} }' ``` diff --git a/docs/features/nixl_connector_usage.md b/docs/features/nixl_connector_usage.md index cb5a3dca035..0f0cbd55354 100644 --- a/docs/features/nixl_connector_usage.md +++ b/docs/features/nixl_connector_usage.md @@ -50,7 +50,7 @@ To select a different backend, set `kv_connector_extra_config.backends` in `--kv vllm serve \ --kv-transfer-config '{ "kv_connector":"NixlConnector", - "kv_role":"kv_both", + "kv_role":"kv_producer", "kv_connector_extra_config":{"backends":["LIBFABRIC"]} }' ``` @@ -60,7 +60,7 @@ You can also pass JSON keys individually using dotted arguments, and you can app ```bash vllm serve \ --kv-transfer-config.kv_connector NixlConnector \ - --kv-transfer-config.kv_role kv_both \ + --kv-transfer-config.kv_role kv_producer \ --kv-transfer-config.kv_connector_extra_config.backends+ LIBFABRIC ``` @@ -81,7 +81,7 @@ VLLM_NIXL_SIDE_CHANNEL_PORT=5600 \ vllm serve Qwen/Qwen3-0.6B \ --port 8100 \ --enforce-eager \ - --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both","kv_load_failure_policy":"fail"}' + --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_producer","kv_load_failure_policy":"fail"}' ``` ### Consumer (Decoder) Configuration @@ -96,7 +96,7 @@ VLLM_NIXL_SIDE_CHANNEL_PORT=5601 \ vllm serve Qwen/Qwen3-0.6B \ --port 8200 \ --enforce-eager \ - --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both","kv_load_failure_policy":"fail"}' + --kv-transfer-config 
'{"kv_connector":"NixlConnector","kv_role":"kv_consumer","kv_load_failure_policy":"fail"}' ``` ### Proxy Server @@ -212,10 +212,21 @@ sequenceDiagram Enable bidirectional KV transfer by setting `bidirectional_kv_xfer` in `kv_connector_extra_config` on **both** P and D instances: ```bash +# Prefill instance vllm serve \ --kv-transfer-config '{ "kv_connector": "NixlConnector", - "kv_role": "kv_both", + "kv_role": "kv_producer", + "kv_connector_extra_config": { + "bidirectional_kv_xfer": true + } + }' + +# Decode instance +vllm serve \ + --kv-transfer-config '{ + "kv_connector": "NixlConnector", + "kv_role": "kv_consumer", "kv_connector_extra_config": { "bidirectional_kv_xfer": true } @@ -359,11 +370,10 @@ For multi-host DP deployment, only need to provide the host/port of the head ins - **kv_producer**: For prefiller instances that generate KV caches - **kv_consumer**: For decoder instances that consume KV caches from prefiller -- **kv_both**: Enables symmetric functionality where the connector can act as both producer and consumer. This provides flexibility for experimental setups and scenarios where the role distinction is not predetermined. +- **kv_both** (deprecated): Previously used as a catch-all when the role was not predetermined. This value is now deprecated for NixlConnector and will be removed in a future release. -!!! tip - NixlConnector currently does not distinguish `kv_role`; the actual prefiller/decoder roles are determined by the upper-level proxy (e.g., `toy_proxy_server.py` using `--prefiller-hosts` and `--decoder-hosts`). - Therefore, `kv_role` in `--kv-transfer-config` is effectively a placeholder and does not affect NixlConnector's behavior. +!!! warning + `kv_role="kv_both"` is deprecated for NixlConnector. Please set `kv_role="kv_producer"` for prefill instances and `kv_role="kv_consumer"` for decode instances. See [#33702](https://github.com/vllm-project/vllm/issues/33702) for details. 
### KV Load Failure Policy diff --git a/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh index bde246c9b66..0e7f6af7e38 100755 --- a/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh +++ b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh @@ -49,11 +49,13 @@ else KV_EXTRA_CONFIG='' fi -# Build the kv-transfer-config once +# Build the kv-transfer-config for P and D if [[ "$KV_BUFFER_DEVICE" == "cuda" ]]; then - KV_CONFIG='{"kv_connector":"NixlConnector","kv_role":"kv_both"'${KV_CONFIG_HETERO_LAYOUT}${KV_EXTRA_CONFIG}'}' + KV_CONFIG_P='{"kv_connector":"NixlConnector","kv_role":"kv_producer"'${KV_CONFIG_HETERO_LAYOUT}${KV_EXTRA_CONFIG}'}' + KV_CONFIG_D='{"kv_connector":"NixlConnector","kv_role":"kv_consumer"'${KV_CONFIG_HETERO_LAYOUT}${KV_EXTRA_CONFIG}'}' else - KV_CONFIG="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"$KV_BUFFER_DEVICE\""${KV_CONFIG_HETERO_LAYOUT}${KV_EXTRA_CONFIG}"}" + KV_CONFIG_P="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_producer\",\"kv_buffer_device\":\"$KV_BUFFER_DEVICE\""${KV_CONFIG_HETERO_LAYOUT}${KV_EXTRA_CONFIG}"}" + KV_CONFIG_D="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_consumer\",\"kv_buffer_device\":\"$KV_BUFFER_DEVICE\""${KV_CONFIG_HETERO_LAYOUT}${KV_EXTRA_CONFIG}"}" fi # Models to run @@ -159,7 +161,7 @@ run_tests_for_model() { --block-size ${PREFILL_BLOCK_SIZE} \ --gpu-memory-utilization $GPU_MEMORY_UTILIZATION \ --tensor-parallel-size $PREFILLER_TP_SIZE \ - --kv-transfer-config '$KV_CONFIG'" + --kv-transfer-config '$KV_CONFIG_P'" if [[ -n "$VLLM_SERVE_EXTRA_ARGS" ]]; then IFS=',' read -r -a extra_args <<< "$VLLM_SERVE_EXTRA_ARGS" for arg in "${extra_args[@]}"; do @@ -207,7 +209,7 @@ run_tests_for_model() { --enforce-eager \ --block-size ${DECODE_BLOCK_SIZE} \ --gpu-memory-utilization $GPU_MEMORY_UTILIZATION \ - --kv-transfer-config '$KV_CONFIG'" + --kv-transfer-config '$KV_CONFIG_D'" if 
[[ -n "$VLLM_SERVE_EXTRA_ARGS" ]]; then IFS=',' read -r -a extra_args <<< "$VLLM_SERVE_EXTRA_ARGS" for arg in "${extra_args[@]}"; do diff --git a/tests/v1/kv_connector/nixl_integration/spec_decode_acceptance_test.sh b/tests/v1/kv_connector/nixl_integration/spec_decode_acceptance_test.sh index bc90680a533..10e119d48a9 100755 --- a/tests/v1/kv_connector/nixl_integration/spec_decode_acceptance_test.sh +++ b/tests/v1/kv_connector/nixl_integration/spec_decode_acceptance_test.sh @@ -193,9 +193,11 @@ run_test_for_device() { local kv_device=$1 if [[ "$kv_device" == "cuda" ]]; then - local kv_config='{"kv_connector":"NixlConnector","kv_role":"kv_both"}' + local kv_config_p='{"kv_connector":"NixlConnector","kv_role":"kv_producer"}' + local kv_config_d='{"kv_connector":"NixlConnector","kv_role":"kv_consumer"}' else - local kv_config="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"${kv_device}\"}" + local kv_config_p="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_producer\",\"kv_buffer_device\":\"${kv_device}\"}" + local kv_config_d="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_consumer\",\"kv_buffer_device\":\"${kv_device}\"}" fi echo "" @@ -248,7 +250,7 @@ run_test_for_device() { --block-size ${BLOCK_SIZE} \ --gpu-memory-utilization $GPU_MEMORY_UTILIZATION \ --tensor-parallel-size $PREFILLER_TP_SIZE \ - --kv-transfer-config "$kv_config" \ + --kv-transfer-config "$kv_config_p" \ --speculative-config "$PREFILL_SPEC_CONFIG" \ --attention-backend $ATTENTION_BACKEND \ ${EXTRA_SERVE_ARGS[@]+"${EXTRA_SERVE_ARGS[@]}"} & @@ -287,7 +289,7 @@ run_test_for_device() { --block-size ${BLOCK_SIZE} \ --gpu-memory-utilization $GPU_MEMORY_UTILIZATION \ --tensor-parallel-size $DECODER_TP_SIZE \ - --kv-transfer-config "$kv_config" \ + --kv-transfer-config "$kv_config_d" \ --speculative-config "$DECODE_SPEC_CONFIG" \ --attention-backend $ATTENTION_BACKEND \ ${EXTRA_SERVE_ARGS[@]+"${EXTRA_SERVE_ARGS[@]}"} & diff --git 
a/tests/v1/kv_connector/unit/test_multi_connector.py b/tests/v1/kv_connector/unit/test_multi_connector.py index 09f29c5d3cb..f78037a1431 100644 --- a/tests/v1/kv_connector/unit/test_multi_connector.py +++ b/tests/v1/kv_connector/unit/test_multi_connector.py @@ -1058,7 +1058,7 @@ def test_multi_connector_mixed_hma_disables_hybrid_kv_cache(monkeypatch): "connectors": [ { "kv_connector": "NixlConnector", - "kv_role": "kv_both", + "kv_role": "kv_consumer", }, { "kv_connector": "MockConnector", diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index 1a7c35cacb8..6f6d8b1ca98 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -1363,7 +1363,7 @@ def test_abort_timeout_on_prefiller(monkeypatch, distributed_executor_backend): timeout = 6 kv_transfer_config = KVTransferConfig( kv_connector="NixlConnector", - kv_role="kv_both", + kv_role="kv_consumer", kv_connector_extra_config={"kv_lease_duration": timeout}, ) llm_kwargs = { @@ -2737,3 +2737,50 @@ def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario) f"got {notif!r} (expected {expected_notif!r}, " f"buggy form would be {bad_notif!r})" ) + + +def test_kv_both_deprecation_warning(default_vllm_config, dist_init): + """kv_role='kv_both' should emit a deprecation log warning.""" + from unittest.mock import patch + + from vllm.logger import _print_warning_once + + _print_warning_once.cache_clear() + + vllm_config = create_vllm_config(kv_role="kv_both") + + with patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl.connector.logger" + ) as mock_logger: + # The patched logger is a mock; warning_once calls are recorded on it. + NixlConnector( + vllm_config, + KVConnectorRole.WORKER, + make_kv_cache_config(block_size=16), + ) + + mock_logger.warning_once.assert_called_once() + msg = mock_logger.warning_once.call_args[0][0] + assert "kv_role='kv_both'" in msg + assert "deprecated" in msg +
+ +def test_explicit_kv_role_no_deprecation_warning(default_vllm_config, dist_init): + """kv_role='kv_consumer' or 'kv_producer' should NOT emit a warning.""" + from unittest.mock import patch + + for role in ("kv_consumer", "kv_producer"): + vllm_config = create_vllm_config(kv_role=role) + with patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl.connector.logger" + ) as mock_logger: + NixlConnector( + vllm_config, + KVConnectorRole.WORKER, + make_kv_cache_config(block_size=16), + ) + + # Real assert so the role-specific message is reported on failure. + assert not mock_logger.warning_once.called, ( + f"kv_role={role!r} should not emit deprecation warning" + ) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector_hma.py b/tests/v1/kv_connector/unit/test_nixl_connector_hma.py index 80088809469..6e399db7b14 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector_hma.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector_hma.py @@ -453,7 +453,7 @@ def test_fewer_blocks_with_hma(monkeypatch, model_name, sw_size): """ kv_transfer_config = KVTransferConfig( kv_connector="NixlConnector", - kv_role="kv_both", + kv_role="kv_consumer", ) block_size = 16 llm_kwargs = { diff --git a/tests/v1/kv_connector/unit/test_rixl_gpu_mem_diag.py b/tests/v1/kv_connector/unit/test_rixl_gpu_mem_diag.py index 3a3ef2a88a6..c3adc05e3ef 100644 --- a/tests/v1/kv_connector/unit/test_rixl_gpu_mem_diag.py +++ b/tests/v1/kv_connector/unit/test_rixl_gpu_mem_diag.py @@ -75,7 +75,7 @@ def test_gpu_memory_rixl_hma(model_name, sw_size): "gpu_memory_utilization": 0.5, "kv_transfer_config": KVTransferConfig( kv_connector="NixlConnector", - kv_role="kv_both", + kv_role="kv_consumer", ), "max_model_len": 2048, "disable_hybrid_kv_cache_manager": False, diff --git a/tests/v1/kv_connector/unit/utils.py b/tests/v1/kv_connector/unit/utils.py index 1b892849d90..9d801f772a6 100644 --- a/tests/v1/kv_connector/unit/utils.py +++ b/tests/v1/kv_connector/unit/utils.py @@ -103,7 +103,7 @@ def create_vllm_config( kv_load_failure_policy: Literal["recompute",
"fail"] = "fail", kv_connector: str = "NixlConnector", kv_connector_module_path: str | None = None, - kv_role: str = "kv_both", + kv_role: str = "kv_consumer", disable_hybrid_kv_cache_manager: bool | None = None, ) -> VllmConfig: """Initialize VllmConfig For Testing.""" diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/connector.py index 187322b4ae4..dad81e84c45 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/connector.py @@ -94,6 +94,15 @@ class NixlConnector(KVConnectorBase_V1, SupportsHMA): super().__init__(vllm_config, role, kv_cache_config) assert vllm_config.kv_transfer_config is not None assert vllm_config.kv_transfer_config.engine_id is not None + + if vllm_config.kv_transfer_config.kv_role == "kv_both": + logger.warning_once( + "Using kv_role='kv_both' with NixlConnector is deprecated " + "and will be removed in a future release. Please set " + "kv_role='kv_producer' for prefill instances and " + "kv_role='kv_consumer' for decode instances. " + ) + self.kv_cache_config = kv_cache_config self.engine_id: EngineId = vllm_config.kv_transfer_config.engine_id self.kv_transfer_config = vllm_config.kv_transfer_config