[Feature] SSL support for dp supervisor (#43688)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Wentao Ye
2026-05-29 15:28:12 -04:00
committed by GitHub
parent acbc203340
commit 5dbf1605a0
2 changed files with 143 additions and 22 deletions
+125 -18
View File
@@ -21,7 +21,10 @@ import asyncio
import contextlib
import os
import signal
import subprocess
import tempfile
import time
from pathlib import Path
from types import SimpleNamespace
import aiohttp
@@ -35,6 +38,7 @@ from vllm.entrypoints.openai.dp_supervisor import (
DPSupervisor,
_build_vllm_dp_server_args,
infer_multi_port_external_lb_start_rank,
validate_multi_port_external_lb_args,
)
from vllm.logger import init_logger
@@ -75,6 +79,8 @@ def _make_unit_args(**overrides) -> argparse.Namespace:
"ssl_keyfile": None,
"ssl_certfile": None,
"ssl_ca_certs": None,
"ssl_cert_reqs": 0,
"ssl_ciphers": None,
"node_rank": 1,
"tensor_parallel_size": 1,
"pipeline_parallel_size": 1,
@@ -108,6 +114,8 @@ def _make_args(**overrides) -> argparse.Namespace:
ssl_keyfile=None,
ssl_certfile=None,
ssl_ca_certs=None,
ssl_cert_reqs=0,
ssl_ciphers=None,
node_rank=0,
tensor_parallel_size=1,
pipeline_parallel_size=1,
@@ -118,6 +126,33 @@ def _make_args(**overrides) -> argparse.Namespace:
return argparse.Namespace(**base)
def _generate_self_signed_cert(cert_dir: Path) -> tuple[Path, Path]:
"""Generate a self-signed certificate for HTTPS lifecycle tests."""
cert_file = cert_dir / "cert.pem"
key_file = cert_dir / "key.pem"
subprocess.run(
[
"openssl",
"req",
"-x509",
"-newkey",
"rsa:2048",
"-keyout",
str(key_file),
"-out",
str(cert_file),
"-days",
"1",
"-nodes",
"-subj",
"/CN=localhost",
],
check=True,
capture_output=True,
)
return cert_file, key_file
# ---------------------------------------------------------------------------
# Unit tests
# ---------------------------------------------------------------------------
@@ -141,6 +176,15 @@ def test_build_multi_port_external_lb_child_args_sets_external_rank_server():
assert child_args.api_server_count == 1
def test_validate_multi_port_external_lb_args_allows_ssl():
args = _make_unit_args(
ssl_keyfile="/tmp/server.key",
ssl_certfile="/tmp/server.crt",
ssl_ca_certs="/tmp/ca.crt",
)
validate_multi_port_external_lb_args(args)
def test_aggregates_health():
supervisor = DPSupervisor(_make_unit_args())
supervisor._is_ready = True
@@ -236,10 +280,18 @@ class MockVLLMServer:
Health state is toggled by the test via set_healthy().
"""
def __init__(self, port: int, drain_seconds: float = 0.0) -> None:
def __init__(
self,
port: int,
drain_seconds: float = 0.0,
ssl_keyfile: str | None = None,
ssl_certfile: str | None = None,
) -> None:
self.port = port
self._healthy = False
self._drain_seconds = drain_seconds
self._ssl_keyfile = ssl_keyfile
self._ssl_certfile = ssl_certfile
self._server: uvicorn.Server | None = None
self._serve_task: asyncio.Task | None = None
@@ -274,6 +326,8 @@ class MockVLLMServer:
port=self.port,
log_level="warning",
lifespan="off",
ssl_keyfile=self._ssl_keyfile,
ssl_certfile=self._ssl_certfile,
)
self._server = uvicorn.Server(config)
@@ -312,7 +366,11 @@ class MockVLLMServer:
def launch_mock_vllm(child_args: argparse.Namespace, env_updates: dict[str, str]):
logger.info("Launching mock vLLM on port %s", child_args.port)
mock_vllm = MockVLLMServer(port=child_args.port)
mock_vllm = MockVLLMServer(
port=child_args.port,
ssl_keyfile=child_args.ssl_keyfile,
ssl_certfile=child_args.ssl_certfile,
)
asyncio.run(mock_vllm.start())
@@ -320,7 +378,12 @@ def launch_mock_vllm_with_drain(
child_args: argparse.Namespace, env_updates: dict[str, str]
):
logger.info("Launching mock vLLM with 15s drain on port %s", child_args.port)
mock_vllm = MockVLLMServer(port=child_args.port, drain_seconds=10.0)
mock_vllm = MockVLLMServer(
port=child_args.port,
drain_seconds=10.0,
ssl_keyfile=child_args.ssl_keyfile,
ssl_certfile=child_args.ssl_certfile,
)
asyncio.run(mock_vllm.start())
@@ -329,15 +392,16 @@ def launch_mock_vllm_with_drain(
# ---------------------------------------------------------------------------
async def _poll_supervisor_health(expected_status: int) -> bool:
async def _poll_supervisor_health(expected_status: int, use_ssl: bool = False) -> bool:
"""
Poll GET /health on the supervisor until expected_status is seen.
A connection error is treated as 503-equivalent when expected_status != 200.
"""
url = f"http://127.0.0.1:{_SUPERVISOR_PORT}/health"
scheme = "https" if use_ssl else "http"
url = f"{scheme}://127.0.0.1:{_SUPERVISOR_PORT}/health"
async with aiohttp.ClientSession() as session:
try:
async with session.get(url) as resp:
async with session.get(url, ssl=False if use_ssl else None) as resp:
if resp.status != expected_status:
print(f"expected: {expected_status=}, got: {resp.status=}")
return False
@@ -349,12 +413,15 @@ async def _poll_supervisor_health(expected_status: int) -> bool:
return True
async def _poll_until_api_server_running(port: int, retries: int = 10) -> None:
url = f"http://127.0.0.1:{port}/health"
async def _poll_until_api_server_running(
port: int, retries: int = 10, use_ssl: bool = False
) -> None:
scheme = "https" if use_ssl else "http"
url = f"{scheme}://127.0.0.1:{port}/health"
async with aiohttp.ClientSession() as session:
for _ in range(retries):
try:
async with session.get(url) as resp:
async with session.get(url, ssl=False if use_ssl else None) as resp:
if resp.status != 200:
return
await asyncio.sleep(1.0)
@@ -363,22 +430,34 @@ async def _poll_until_api_server_running(port: int, retries: int = 10) -> None:
await asyncio.sleep(1.0)
async def _set_healthy(port: int) -> None:
url = f"http://127.0.0.1:{port}/set_healthy"
async with aiohttp.ClientSession() as session, session.get(url) as resp:
async def _set_healthy(port: int, use_ssl: bool = False) -> None:
scheme = "https" if use_ssl else "http"
url = f"{scheme}://127.0.0.1:{port}/set_healthy"
async with (
aiohttp.ClientSession() as session,
session.get(url, ssl=False if use_ssl else None) as resp,
):
assert resp.status == 200
async def _set_unhealthy(port: int) -> None:
url = f"http://127.0.0.1:{port}/set_unhealthy"
async with aiohttp.ClientSession() as session, session.get(url) as resp:
async def _set_unhealthy(port: int, use_ssl: bool = False) -> None:
scheme = "https" if use_ssl else "http"
url = f"{scheme}://127.0.0.1:{port}/set_unhealthy"
async with (
aiohttp.ClientSession() as session,
session.get(url, ssl=False if use_ssl else None) as resp,
):
assert resp.status == 200
async def _kill_server(port: int) -> None:
url = f"http://127.0.0.1:{port}/kill"
async def _kill_server(port: int, use_ssl: bool = False) -> None:
scheme = "https" if use_ssl else "http"
url = f"{scheme}://127.0.0.1:{port}/kill"
try:
async with aiohttp.ClientSession() as session, session.get(url) as resp:
async with (
aiohttp.ClientSession() as session,
session.get(url, ssl=False if use_ssl else None) as resp,
):
assert resp.status != 200
except Exception as e:
assert isinstance(e, aiohttp.ClientConnectorError)
@@ -455,6 +534,34 @@ async def test_basic_lifecycle(monkeypatch):
print("everything was cleaned up!")
@pytest.mark.asyncio
async def test_basic_lifecycle_with_ssl(monkeypatch):
with tempfile.TemporaryDirectory() as cert_dir:
cert_file, key_file = _generate_self_signed_cert(Path(cert_dir))
args = _make_args(
ssl_keyfile=str(key_file),
ssl_certfile=str(cert_file),
)
vllm_server_ports = [_CHILD_PORT_BASE + i for i in range(_N_CHILDREN)]
async with _run_supervisor(args, monkeypatch) as (supervisor, _task):
assert await _poll_supervisor_health(503, use_ssl=True)
assert not supervisor.is_ready
for port in vllm_server_ports:
assert await _poll_supervisor_health(503, use_ssl=True)
assert not supervisor.is_ready
await _poll_until_api_server_running(port, use_ssl=True)
for port in vllm_server_ports:
await _set_healthy(port, use_ssl=True)
await asyncio.sleep(1.0)
assert await _poll_supervisor_health(200, use_ssl=True)
assert supervisor.is_ready
@pytest.mark.asyncio
async def test_failed_startup(monkeypatch):
"""
+18 -4
View File
@@ -55,9 +55,9 @@ def validate_multi_port_external_lb_args(args: argparse.Namespace) -> None:
raise ValueError(
"Error: --data-parallel-multi-port-external-lb does not support --uds"
)
if any((args.ssl_keyfile, args.ssl_certfile, args.ssl_ca_certs)):
if bool(args.ssl_keyfile) != bool(args.ssl_certfile):
raise ValueError(
"Error: --data-parallel-multi-port-external-lb does not support HTTPS yet"
"Error: --ssl-keyfile and --ssl-certfile must be provided together"
)
if args.api_server_count not in (None, 1):
raise ValueError(
@@ -151,7 +151,8 @@ def _child_base_url(args: argparse.Namespace, port: int) -> str:
host = "127.0.0.1"
elif host == "::":
host = "::1"
return f"http://{host}:{port}"
scheme = "https" if args.ssl_keyfile and args.ssl_certfile else "http"
return f"{scheme}://{host}:{port}"
def _join_processes_with_timeout(processes: list[BaseProcess], timeout: float) -> None:
@@ -178,7 +179,15 @@ async def _probe_endpoint(
"""
for iteration in range(conn_err_failure_threshold):
try:
async with session.get(_child_base_url(args, port) + path) as response:
probe_ssl = None
if args.ssl_keyfile and args.ssl_certfile:
# Probes target node-local child servers over loopback, so skip
# certificate verification to avoid SAN/hostname mismatches for
# localhost/127.0.0.1 deployments.
probe_ssl = False
async with session.get(
_child_base_url(args, port) + path, ssl=probe_ssl
) as response:
# vLLM returns 503 on EngineDeadError, so we should return
# immediately if vLLM responds with a non-200 status code.
return response.status == HTTPStatus.OK
@@ -272,6 +281,11 @@ class DPSupervisor:
host=host,
port=self.supervisor_port,
log_level=self.args.uvicorn_log_level,
ssl_keyfile=self.args.ssl_keyfile,
ssl_certfile=self.args.ssl_certfile,
ssl_ca_certs=self.args.ssl_ca_certs,
ssl_cert_reqs=self.args.ssl_cert_reqs,
ssl_ciphers=self.args.ssl_ciphers,
)
supervisor_server = uvicorn.Server(config)
supervisor_server_task = asyncio.create_task(