mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[Feature] SSL support for dp supervisor (#43688)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
@@ -21,7 +21,10 @@ import asyncio
|
||||
import contextlib
|
||||
import os
|
||||
import signal
|
||||
import subprocess
|
||||
import tempfile
|
||||
import time
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
import aiohttp
|
||||
@@ -35,6 +38,7 @@ from vllm.entrypoints.openai.dp_supervisor import (
|
||||
DPSupervisor,
|
||||
_build_vllm_dp_server_args,
|
||||
infer_multi_port_external_lb_start_rank,
|
||||
validate_multi_port_external_lb_args,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
|
||||
@@ -75,6 +79,8 @@ def _make_unit_args(**overrides) -> argparse.Namespace:
|
||||
"ssl_keyfile": None,
|
||||
"ssl_certfile": None,
|
||||
"ssl_ca_certs": None,
|
||||
"ssl_cert_reqs": 0,
|
||||
"ssl_ciphers": None,
|
||||
"node_rank": 1,
|
||||
"tensor_parallel_size": 1,
|
||||
"pipeline_parallel_size": 1,
|
||||
@@ -108,6 +114,8 @@ def _make_args(**overrides) -> argparse.Namespace:
|
||||
ssl_keyfile=None,
|
||||
ssl_certfile=None,
|
||||
ssl_ca_certs=None,
|
||||
ssl_cert_reqs=0,
|
||||
ssl_ciphers=None,
|
||||
node_rank=0,
|
||||
tensor_parallel_size=1,
|
||||
pipeline_parallel_size=1,
|
||||
@@ -118,6 +126,33 @@ def _make_args(**overrides) -> argparse.Namespace:
|
||||
return argparse.Namespace(**base)
|
||||
|
||||
|
||||
def _generate_self_signed_cert(cert_dir: Path) -> tuple[Path, Path]:
    """Create a throwaway self-signed cert/key pair for HTTPS lifecycle tests.

    Shells out to the ``openssl`` CLI to write an unencrypted 2048-bit RSA key
    and a matching X.509 certificate (subject ``/CN=localhost``, valid 1 day).

    Args:
        cert_dir: Directory that receives ``cert.pem`` and ``key.pem``.

    Returns:
        ``(cert_file, key_file)`` paths, in that order.

    Raises:
        subprocess.CalledProcessError: If ``openssl`` exits non-zero.
    """
    key_file = cert_dir / "key.pem"
    cert_file = cert_dir / "cert.pem"
    openssl_cmd = [
        "openssl",
        "req",
        "-x509",
        "-newkey",
        "rsa:2048",
        "-keyout",
        str(key_file),
        "-out",
        str(cert_file),
        "-days",
        "1",
        "-nodes",  # leave the private key unencrypted (no passphrase prompt)
        "-subj",
        "/CN=localhost",
    ]
    # stdout/stderr are captured rather than echoed; check=True turns a
    # non-zero exit status into CalledProcessError.
    subprocess.run(openssl_cmd, check=True, capture_output=True)
    return cert_file, key_file
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Unit tests
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -141,6 +176,15 @@ def test_build_multi_port_external_lb_child_args_sets_external_rank_server():
|
||||
assert child_args.api_server_count == 1
|
||||
|
||||
|
||||
def test_validate_multi_port_external_lb_args_allows_ssl():
    """Validation must accept args that carry a key, a cert and a CA bundle."""
    ssl_overrides = {
        "ssl_keyfile": "/tmp/server.key",
        "ssl_certfile": "/tmp/server.crt",
        "ssl_ca_certs": "/tmp/ca.crt",
    }
    # The test passes simply by the validator returning without raising.
    validate_multi_port_external_lb_args(_make_unit_args(**ssl_overrides))
|
||||
|
||||
|
||||
def test_aggregates_health():
|
||||
supervisor = DPSupervisor(_make_unit_args())
|
||||
supervisor._is_ready = True
|
||||
@@ -236,10 +280,18 @@ class MockVLLMServer:
|
||||
Health state is toggled by the test via set_healthy().
|
||||
"""
|
||||
|
||||
def __init__(self, port: int, drain_seconds: float = 0.0) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
port: int,
|
||||
drain_seconds: float = 0.0,
|
||||
ssl_keyfile: str | None = None,
|
||||
ssl_certfile: str | None = None,
|
||||
) -> None:
|
||||
self.port = port
|
||||
self._healthy = False
|
||||
self._drain_seconds = drain_seconds
|
||||
self._ssl_keyfile = ssl_keyfile
|
||||
self._ssl_certfile = ssl_certfile
|
||||
self._server: uvicorn.Server | None = None
|
||||
self._serve_task: asyncio.Task | None = None
|
||||
|
||||
@@ -274,6 +326,8 @@ class MockVLLMServer:
|
||||
port=self.port,
|
||||
log_level="warning",
|
||||
lifespan="off",
|
||||
ssl_keyfile=self._ssl_keyfile,
|
||||
ssl_certfile=self._ssl_certfile,
|
||||
)
|
||||
self._server = uvicorn.Server(config)
|
||||
|
||||
@@ -312,7 +366,11 @@ class MockVLLMServer:
|
||||
|
||||
def launch_mock_vllm(child_args: argparse.Namespace, env_updates: dict[str, str]):
    """Start a mock vLLM server for one child and block until ``start()`` returns.

    Args:
        child_args: Child arguments; ``port``, ``ssl_keyfile`` and
            ``ssl_certfile`` are read to configure the mock server.
        env_updates: Not used by this launcher.
            NOTE(review): presumably kept so the signature matches what the
            supervisor calls -- confirm.
    """
    port = child_args.port
    logger.info("Launching mock vLLM on port %s", port)
    # Forward the TLS key/cert from the child args to the mock server.
    server = MockVLLMServer(
        port=port,
        ssl_keyfile=child_args.ssl_keyfile,
        ssl_certfile=child_args.ssl_certfile,
    )
    asyncio.run(server.start())
|
||||
|
||||
|
||||
def launch_mock_vllm_with_drain(
    child_args: argparse.Namespace, env_updates: dict[str, str]
):
    """Start a mock vLLM server configured with a 10-second drain period.

    Blocks until the mock server's ``start()`` coroutine returns.

    Args:
        child_args: Child arguments; ``port``, ``ssl_keyfile`` and
            ``ssl_certfile`` are read to configure the mock server.
        env_updates: Not used by this launcher.
            NOTE(review): presumably kept so the signature matches what the
            supervisor calls -- confirm.
    """
    drain_seconds = 10.0
    # Log the drain value that is actually passed to the server; the message
    # previously hard-coded "15s" while 10.0 was used.
    logger.info(
        "Launching mock vLLM with %gs drain on port %s",
        drain_seconds,
        child_args.port,
    )
    mock_vllm = MockVLLMServer(
        port=child_args.port,
        drain_seconds=drain_seconds,
        ssl_keyfile=child_args.ssl_keyfile,
        ssl_certfile=child_args.ssl_certfile,
    )
    asyncio.run(mock_vllm.start())
|
||||
|
||||
|
||||
@@ -329,15 +392,16 @@ def launch_mock_vllm_with_drain(
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _poll_supervisor_health(expected_status: int) -> bool:
|
||||
async def _poll_supervisor_health(expected_status: int, use_ssl: bool = False) -> bool:
|
||||
"""
|
||||
Poll GET /health on the supervisor until expected_status is seen.
|
||||
A connection error is treated as 503-equivalent when expected_status != 200.
|
||||
"""
|
||||
url = f"http://127.0.0.1:{_SUPERVISOR_PORT}/health"
|
||||
scheme = "https" if use_ssl else "http"
|
||||
url = f"{scheme}://127.0.0.1:{_SUPERVISOR_PORT}/health"
|
||||
async with aiohttp.ClientSession() as session:
|
||||
try:
|
||||
async with session.get(url) as resp:
|
||||
async with session.get(url, ssl=False if use_ssl else None) as resp:
|
||||
if resp.status != expected_status:
|
||||
print(f"expected: {expected_status=}, got: {resp.status=}")
|
||||
return False
|
||||
@@ -349,12 +413,15 @@ async def _poll_supervisor_health(expected_status: int) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
async def _poll_until_api_server_running(port: int, retries: int = 10) -> None:
|
||||
url = f"http://127.0.0.1:{port}/health"
|
||||
async def _poll_until_api_server_running(
|
||||
port: int, retries: int = 10, use_ssl: bool = False
|
||||
) -> None:
|
||||
scheme = "https" if use_ssl else "http"
|
||||
url = f"{scheme}://127.0.0.1:{port}/health"
|
||||
async with aiohttp.ClientSession() as session:
|
||||
for _ in range(retries):
|
||||
try:
|
||||
async with session.get(url) as resp:
|
||||
async with session.get(url, ssl=False if use_ssl else None) as resp:
|
||||
if resp.status != 200:
|
||||
return
|
||||
await asyncio.sleep(1.0)
|
||||
@@ -363,22 +430,34 @@ async def _poll_until_api_server_running(port: int, retries: int = 10) -> None:
|
||||
await asyncio.sleep(1.0)
|
||||
|
||||
|
||||
async def _set_healthy(port: int, use_ssl: bool = False) -> None:
    """GET ``/set_healthy`` on the server at ``port`` and require a 200.

    Args:
        port: Loopback (127.0.0.1) port of the server to toggle.
        use_ssl: When true, use ``https`` and skip certificate verification.
    """
    if use_ssl:
        # ssl=False tells aiohttp to skip certificate verification.
        scheme, ssl_opt = "https", False
    else:
        scheme, ssl_opt = "http", None
    endpoint = f"{scheme}://127.0.0.1:{port}/set_healthy"
    async with aiohttp.ClientSession() as session:
        async with session.get(endpoint, ssl=ssl_opt) as resp:
            assert resp.status == 200
|
||||
|
||||
|
||||
async def _set_unhealthy(port: int, use_ssl: bool = False) -> None:
    """GET ``/set_unhealthy`` on the server at ``port`` and require a 200.

    Args:
        port: Loopback (127.0.0.1) port of the server to toggle.
        use_ssl: When true, use ``https`` and skip certificate verification.
    """
    if use_ssl:
        # ssl=False tells aiohttp to skip certificate verification.
        scheme, ssl_opt = "https", False
    else:
        scheme, ssl_opt = "http", None
    endpoint = f"{scheme}://127.0.0.1:{port}/set_unhealthy"
    async with aiohttp.ClientSession() as session:
        async with session.get(endpoint, ssl=ssl_opt) as resp:
            assert resp.status == 200
|
||||
|
||||
|
||||
async def _kill_server(port: int, use_ssl: bool = False) -> None:
    """GET ``/kill`` on the server at ``port`` and expect it not to answer 200.

    The call passes when the connection cannot be established
    (``aiohttp.ClientConnectorError``) or when a response arrives with a
    non-200 status. Any other exception propagates unchanged.

    Args:
        port: Loopback (127.0.0.1) port of the server to kill.
        use_ssl: When true, use ``https`` and skip certificate verification.
    """
    scheme = "https" if use_ssl else "http"
    url = f"{scheme}://127.0.0.1:{port}/kill"
    try:
        async with (
            aiohttp.ClientSession() as session,
            session.get(url, ssl=False if use_ssl else None) as resp,
        ):
            status = resp.status
    except aiohttp.ClientConnectorError:
        # Connection could not be established: accepted as a successful kill.
        return
    # Checked outside the try so a failure here is reported as-is instead of
    # being caught and re-asserted on its exception type.
    assert status != 200, f"/kill on port {port} unexpectedly returned 200"
|
||||
@@ -455,6 +534,34 @@ async def test_basic_lifecycle(monkeypatch):
|
||||
print("everything was cleaned up!")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_basic_lifecycle_with_ssl(monkeypatch):
    """Supervisor serves /health over HTTPS: 503 until children are healthy, then 200."""
    child_ports = [_CHILD_PORT_BASE + idx for idx in range(_N_CHILDREN)]

    # Certificates live in a temp dir that must outlive the supervisor run.
    with tempfile.TemporaryDirectory() as cert_dir:
        cert_file, key_file = _generate_self_signed_cert(Path(cert_dir))
        ssl_args = _make_args(
            ssl_keyfile=str(key_file),
            ssl_certfile=str(cert_file),
        )

        async with _run_supervisor(ssl_args, monkeypatch) as (supervisor, _task):
            # Nothing has been marked healthy yet, so the supervisor is not ready.
            assert await _poll_supervisor_health(503, use_ssl=True)
            assert not supervisor.is_ready

            # Wait for each child to come up; the supervisor must stay
            # not-ready throughout.
            for child_port in child_ports:
                assert await _poll_supervisor_health(503, use_ssl=True)
                assert not supervisor.is_ready
                await _poll_until_api_server_running(child_port, use_ssl=True)

            # Flip every child to healthy, pausing after each toggle.
            for child_port in child_ports:
                await _set_healthy(child_port, use_ssl=True)
                await asyncio.sleep(1.0)

            assert await _poll_supervisor_health(200, use_ssl=True)
            assert supervisor.is_ready
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_failed_startup(monkeypatch):
|
||||
"""
|
||||
|
||||
@@ -55,9 +55,9 @@ def validate_multi_port_external_lb_args(args: argparse.Namespace) -> None:
|
||||
raise ValueError(
|
||||
"Error: --data-parallel-multi-port-external-lb does not support --uds"
|
||||
)
|
||||
if any((args.ssl_keyfile, args.ssl_certfile, args.ssl_ca_certs)):
|
||||
if bool(args.ssl_keyfile) != bool(args.ssl_certfile):
|
||||
raise ValueError(
|
||||
"Error: --data-parallel-multi-port-external-lb does not support HTTPS yet"
|
||||
"Error: --ssl-keyfile and --ssl-certfile must be provided together"
|
||||
)
|
||||
if args.api_server_count not in (None, 1):
|
||||
raise ValueError(
|
||||
@@ -151,7 +151,8 @@ def _child_base_url(args: argparse.Namespace, port: int) -> str:
|
||||
host = "127.0.0.1"
|
||||
elif host == "::":
|
||||
host = "::1"
|
||||
return f"http://{host}:{port}"
|
||||
scheme = "https" if args.ssl_keyfile and args.ssl_certfile else "http"
|
||||
return f"{scheme}://{host}:{port}"
|
||||
|
||||
|
||||
def _join_processes_with_timeout(processes: list[BaseProcess], timeout: float) -> None:
|
||||
@@ -178,7 +179,15 @@ async def _probe_endpoint(
|
||||
"""
|
||||
for iteration in range(conn_err_failure_threshold):
|
||||
try:
|
||||
async with session.get(_child_base_url(args, port) + path) as response:
|
||||
probe_ssl = None
|
||||
if args.ssl_keyfile and args.ssl_certfile:
|
||||
# Probes target node-local child servers over loopback, so skip
|
||||
# certificate verification to avoid SAN/hostname mismatches for
|
||||
# localhost/127.0.0.1 deployments.
|
||||
probe_ssl = False
|
||||
async with session.get(
|
||||
_child_base_url(args, port) + path, ssl=probe_ssl
|
||||
) as response:
|
||||
# vLLM returns 503 on EngineDeadError, so we should return
|
||||
# immediately if vLLM responds with a non-200 status code.
|
||||
return response.status == HTTPStatus.OK
|
||||
@@ -272,6 +281,11 @@ class DPSupervisor:
|
||||
host=host,
|
||||
port=self.supervisor_port,
|
||||
log_level=self.args.uvicorn_log_level,
|
||||
ssl_keyfile=self.args.ssl_keyfile,
|
||||
ssl_certfile=self.args.ssl_certfile,
|
||||
ssl_ca_certs=self.args.ssl_ca_certs,
|
||||
ssl_cert_reqs=self.args.ssl_cert_reqs,
|
||||
ssl_ciphers=self.args.ssl_ciphers,
|
||||
)
|
||||
supervisor_server = uvicorn.Server(config)
|
||||
supervisor_server_task = asyncio.create_task(
|
||||
|
||||
Reference in New Issue
Block a user