[TRTLLM-9181][feat] improve disagg-server prometheus metrics; synchronize workers' clocks when workers are dynamic (#9726)

Signed-off-by: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com>
Lizhi Zhou 2025-12-16 21:16:32 +08:00 committed by GitHub
parent 609d1d0383
commit bd13957e70
25 changed files with 987 additions and 324 deletions

View File

@ -473,10 +473,20 @@ def dim_resolve_negative(dim, ndim):
return tuple(pos)
def get_free_port():
with socket.socket() as sock:
sock.bind(("", 0))
return sock.getsockname()[1]
def get_free_port() -> int:
return get_free_ports(1)[0]
def get_free_ports(num=1) -> List[int]:
sockets = [
socket.socket(socket.AF_INET, socket.SOCK_STREAM) for _ in range(num)
]
for s in sockets:
s.bind(('', 0))
ports = [s.getsockname()[1] for s in sockets]
for s in sockets:
s.close()
return ports
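# Note (illustrative): all sockets stay open until every port has been read back, so a single
# call cannot return the same ephemeral port twice; e.g. a hypothetical
#   disagg_port, ctx_port, gen_port = get_free_ports(3)
# yields three distinct ports (they are only reserved until the sockets close).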
# mpi4py only exports MPI_COMM_TYPE_SHARED, so we define OMPI_COMM_TYPE_HOST here

View File

@ -2,6 +2,7 @@ import asyncio
import json
import os
import random
import socket
import time
from dataclasses import asdict, dataclass
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
@ -29,6 +30,18 @@ def get_worker_key(name: str, role: ServerRole, worker_id: str = "") -> str:
return f"{get_worker_key_prefix(name)}/{worker_id}"
def get_host_from_uri(uri: str) -> str:
return uri.split("://")[1].split(":")[0]
# Get the local IP address used to reach a remote host;
# if no remote host is provided, use Google's public DNS server "8.8.8.8".
def get_local_ip(remote_host: str = "8.8.8.8") -> str:
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
s.connect((remote_host, 80))
return s.getsockname()[0]
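# Note (illustrative): connecting a UDP socket sends no packets; it only makes the OS pick the
# outbound interface, so get_local_ip("10.0.0.5") returns the address of the interface that
# routes to 10.0.0.5 (the address here is hypothetical).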
class DisaggClusterManager:
"""
The cluster manager is responsible for managing the workers in the cluster.
@ -238,18 +251,25 @@ class DisaggClusterWorker:
It will send a heartbeat to the cluster storage every heartbeat_interval_sec seconds.
If the worker heartbeat fails, it will re-register itself.
"""
LOCALHOST_IPS = ["localhost", "127.0.0.1", "0.0.0.0", "::1",
"::"] # nosec B104
def __init__(self, role: ServerRole, host: str, port: int,
config: DisaggClusterConfig, storage: ClusterStorage):
self._role = role
self._host = host
self._port = port
self._config = config
self._cluster_storage = storage
self._stop = False
self._heartbeat_task = None
self._last_heartbeat = 0
self._worker_id = f"{role.name}-{host}:{port}-{int(time.time()*1000)}-{os.getpid()}-{random.randint(0, 1000):03}"
register_host = host
# if the host is a localhost address and the cluster uri is not, register the worker with the local IP that routes to the cluster host
disagg_host = get_host_from_uri(self._config.cluster_uri)
if host in self.LOCALHOST_IPS and disagg_host not in self.LOCALHOST_IPS:
register_host = get_local_ip(disagg_host)
self._host = register_host
self._worker_id = f"{role.name}-{register_host}:{port}-{int(time.time()*1000)}-{os.getpid()}-{random.randint(0, 1000):03}"
def __del__(self):
try:

View File

@ -183,6 +183,9 @@ class OpenAIHttpClient(OpenAIClient):
yield response_dict
# finish the request after the successful response
await self._finish_request(request)
self._metrics_collector.complete_latency_seconds.observe(
get_steady_clock_now_in_seconds() - start_time
)
break # break and skip retries if the whole response is processed without exception
except (aiohttp.ClientError, OSError) as e:
if lines_yielded > 0:
@ -227,25 +230,24 @@ class OpenAIHttpClient(OpenAIClient):
i = 0
async for line in http_response.content.iter_any():
now_time = get_steady_clock_now_in_seconds()
if i == 0:
if hooks:
hooks.on_first_token(server, request)
self._metrics_collector.first_token_latency_seconds.observe(
now_time - last_token_time
)
else:
self._metrics_collector.per_token_latency_seconds.observe(
now_time - last_token_time
)
i += 1
if line:
if i == 0:
if hooks:
hooks.on_first_token(server, request)
self._metrics_collector.first_token_latency_seconds.observe(
now_time - last_token_time
)
else:
self._metrics_collector.per_token_latency_seconds.observe(
now_time - last_token_time
)
i += 1
yield line
await asyncio.sleep(0)
last_token_time = now_time
if hooks:
hooks.on_resp_done(server, request, None)
self._metrics_collector.completed_requests.inc()
self._metrics_collector.complete_latency_seconds.observe(
get_steady_clock_now_in_seconds() - start_time
)
@ -262,6 +264,7 @@ class OpenAIHttpClient(OpenAIClient):
await self._finish_request(request)
async def _finish_request(self, request: UCompletionRequest) -> None:
self._metrics_collector.completed_requests.inc()
await self._router.finish_request(request)
async def collect_metrics(self) -> Dict[str, Any]:

View File

@ -57,11 +57,12 @@ class RawRequestResponseHooks(ResponseHooks):
self.raw_req = raw_req
self.ctx_server = ""
self.gen_server = ""
self.request_arrival_time = raw_req.state.server_arrival_time
self.server_first_token_time = 0
self.perf_metrics_collector = perf_metrics_collector
def on_req_begin(self, request: UCompletionRequest):
...
self.perf_metrics_collector.queue_latency_seconds.observe(get_steady_clock_now_in_seconds() - self.request_arrival_time)
def on_ctx_resp(self, ctx_server: str, response: UCompletionResponse):
self.ctx_server = ctx_server
@ -93,8 +94,8 @@ class OpenAIDisaggServer:
self._metrics_interval_secs = metrics_interval_secs
self._ctx_servers, self._gen_servers = get_ctx_gen_server_addrs(config.server_configs)
self._ctx_router = create_router(config.ctx_router_config, self._ctx_servers, metadata_server_cfg, create_metadata_server(metadata_server_cfg))
self._gen_router = create_router(config.gen_router_config, self._gen_servers, metadata_server_cfg, create_metadata_server(metadata_server_cfg))
self._ctx_router = create_router(config.ctx_router_config, self._ctx_servers, metadata_server_cfg, create_metadata_server(metadata_server_cfg), self._sync_server_clock)
self._gen_router = create_router(config.gen_router_config, self._gen_servers, metadata_server_cfg, create_metadata_server(metadata_server_cfg), self._sync_server_clock)
self._metadata_server = create_metadata_server(metadata_server_cfg)
self._perf_metrics_collector = DisaggPerfMetricsCollector(config.perf_metrics_max_requests)
@ -122,8 +123,10 @@ class OpenAIDisaggServer:
@asynccontextmanager
async def lifespan(app) -> None:
# Prepare servers (sync server clock) when static ctx/gen server list is used
await self._ctx_router.prepare_servers()
await self._gen_router.prepare_servers()
await self._service.setup()
await self._set_steady_clock_offsets()
yield
await self._service.teardown()
@ -133,6 +136,7 @@ class OpenAIDisaggServer:
@self.app.exception_handler(RequestValidationError)
async def validation_exception_handler(_, exc):
self._perf_metrics_collector.validation_exceptions.inc()
return JSONResponse(status_code=400, content={"error": str(exc)})
self.register_routes()
@ -158,8 +162,14 @@ class OpenAIDisaggServer:
def _wrap_entry_point(self, entry_point: Callable) -> Callable:
async def wrapper(req: UCompletionRequest, raw_req: Request) -> Response:
try:
self._perf_metrics_collector.total_requests.inc()
if req.stream:
self._perf_metrics_collector.stream_requests.inc()
else:
self._perf_metrics_collector.nonstream_requests.inc()
hooks = RawRequestResponseHooks(raw_req, self._perf_metrics_collector)
response_or_generator = await entry_point(req, hooks)
self._perf_metrics_collector.total_responses.inc()
if req.stream:
return StreamingResponse(content=response_or_generator, media_type="text/event-stream")
else:
@ -173,9 +183,11 @@ class OpenAIDisaggServer:
logger.error("CppExecutorError: ", traceback.format_exc())
signal.raise_signal(signal.SIGINT)
elif isinstance(exception, HTTPException):
self._perf_metrics_collector.http_exceptions.inc()
logger.error(f"HTTPException {exception.status_code} {exception.detail}: ", traceback.format_exc())
raise exception
else:
self._perf_metrics_collector.internal_errors.inc()
logger.error("Internal server error: ", traceback.format_exc())
raise HTTPException(status_code=500, detail=f"Internal server error {str(exception)}")
@ -199,13 +211,12 @@ class OpenAIDisaggServer:
timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
await uvicorn.Server(config).serve(sockets=sockets)
# TODO: rework this for service discovery, now it's only for static server list
async def _set_steady_clock_offsets(self):
STEADY_CLOCK_OFFSET_ENDPOINT = "/steady_clock_offset"
async def _sync_server_clock(self, server: str):
""" Sync the ctx/gen server's steady clock with the disagg-server's steady clock (in case NTP service is not running). """
async def query_steady_clock_offset(session: aiohttp.ClientSession, server_url: str) -> tuple[Optional[float], Optional[float]]:
try:
originate_ts = get_steady_clock_now_in_seconds()
async with session.get(server_url + STEADY_CLOCK_OFFSET_ENDPOINT) as response:
async with session.get(server_url) as response:
destination_ts = get_steady_clock_now_in_seconds()
if response.status == 200:
response_content = await response.json()
@ -222,12 +233,11 @@ class OpenAIDisaggServer:
async def set_steady_clock_offset(session: aiohttp.ClientSession, server_url: str, offset: float) -> None:
payload = {"offset": offset}
async with session.post(server_url + STEADY_CLOCK_OFFSET_ENDPOINT, json=payload) as response:
async with session.post(server_url, json=payload) as response:
if response.status != 200:
logger.warning(f"Cannot set disagg server steady clock offset for server {server_url}, the perf metrics timestamps could be mis-aligned")
async def align_steady_clock_offset(session: aiohttp.ClientSession, server_url: str) -> None:
server_url = f"http://{server_url}" if not server_url.startswith("http://") else server_url
delay, offset = await query_steady_clock_offset(session, server_url)
if delay is None or offset is None:
logger.warning(f"Unable to measure steady clock offset for {server_url}; skipping adjustment")
@ -236,7 +246,13 @@ class OpenAIDisaggServer:
# Negate the offset so that worker servers can adjust their steady clock by adding the new offset
await set_steady_clock_offset(session, server_url, -offset)
async with aiohttp.ClientSession(
connector=aiohttp.TCPConnector(limit=0, limit_per_host=0, force_close=True),
timeout=aiohttp.ClientTimeout(total=self._req_timeout_secs)) as session:
await asyncio.gather(*[align_steady_clock_offset(session, server_url) for server_url in self._ctx_servers + self._gen_servers])
server_scheme = "http://" if not server.startswith("http://") else ""
server_url = f"{server_scheme}{server}/steady_clock_offset"
try:
async with aiohttp.ClientSession(
connector=aiohttp.TCPConnector(limit=0, limit_per_host=0, force_close=True),
timeout=aiohttp.ClientTimeout(total=self._req_timeout_secs)) as session:
await align_steady_clock_offset(session, server_url)
except (aiohttp.ClientError, OSError) as e:
logger.warning(f"Unable to align steady clock offset for {server_url}: {e}; skipping adjustment")

View File

@ -15,7 +15,7 @@
import asyncio
from collections import defaultdict, deque
from dataclasses import dataclass
from typing import Any, Dict, List, Literal, Optional, Union
from typing import Any, Dict, List, Literal, Optional
from tensorrt_llm.llmapi.disagg_utils import ServerRole
@ -64,7 +64,7 @@ class MetricsDefinition:
buckets: Optional[List[float]] = None
METRICS_DEFINITIONS = [
CLIENT_METRICS_DEFINITIONS = [
MetricsDefinition("total_requests", "Total number of requests", "counter"),
MetricsDefinition("error_requests", "Total number of error requests", "counter"),
MetricsDefinition("retry_requests", "Total number of retry requests", "counter"),
@ -96,23 +96,29 @@ ROLE_TO_CLIENT_TYPE = {
}
def instance_metric(definition: MetricsDefinition, role: Optional[ServerRole] = None):
# import lazily to avoid breaking `set_prometheus_multiproc_dir`
from prometheus_client import Counter, Histogram
name = (
f"{ROLE_TO_CLIENT_TYPE[role]}_{definition.name}"
if role in ROLE_TO_CLIENT_TYPE
else definition.name
)
if definition.type == "counter":
return Counter(name, definition.description)
elif definition.type == "histogram":
return Histogram(name, definition.description, buckets=definition.buckets)
else:
raise ValueError(f"Invalid metric type: {definition.type}")
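# For illustration (assuming ROLE_TO_CLIENT_TYPE maps the context role to "ctx"):
#   instance_metric(MetricsDefinition("total_requests", "Total number of requests", "counter"),
#                   ServerRole.CONTEXT)
# creates a prometheus_client Counter named "ctx_total_requests", while role=None keeps the
# bare name; histogram definitions are handled the same way with their bucket lists.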
class ClientMetricsCollector:
def __init__(self, role: ServerRole):
self._role = role
# import lazily to avoid breaking `set_prometheus_multiproc_dir`
from prometheus_client import Counter, Histogram
def instance_metric(definition: MetricsDefinition) -> Union[Counter | Histogram]:
name = f"{ROLE_TO_CLIENT_TYPE[role]}_{definition.name}"
if definition.type == "counter":
return Counter(name, definition.description)
elif definition.type == "histogram":
return Histogram(name, definition.description, buckets=definition.buckets)
else:
raise ValueError(f"Invalid metric type: {definition.type}")
self._metrics = {
definition.name: instance_metric(definition) for definition in METRICS_DEFINITIONS
definition.name: instance_metric(definition, role)
for definition in CLIENT_METRICS_DEFINITIONS
}
def __getattr__(
@ -121,6 +127,23 @@ class ClientMetricsCollector:
return self._metrics[key]
SERVER_METRICS_DEFINITIONS = [
MetricsDefinition("total_requests", "Total number of requests", "counter"),
MetricsDefinition("stream_requests", "Total number of stream requests", "counter"),
MetricsDefinition("nonstream_requests", "Total number of non-stream requests", "counter"),
MetricsDefinition("validation_exceptions", "Total number of validation exceptions", "counter"),
MetricsDefinition("http_exceptions", "Total number of HTTP exceptions", "counter"),
MetricsDefinition("internal_errors", "Total number of internal errors", "counter"),
MetricsDefinition("total_responses", "Total number of responses", "counter"),
MetricsDefinition(
"queue_latency_seconds",
"Histogram of latency from request arrival to being processed in seconds",
"histogram",
SHORT_TIME_BUCKETS,
),
]
class DisaggPerfMetricsCollector:
def __init__(self, max_requests: int):
self._max_requests = max_requests
@ -128,10 +151,17 @@ class DisaggPerfMetricsCollector:
self._server_metrics = defaultdict(dict)
self._lock = asyncio.Lock()
self._clients = []
self._metrics = {
definition.name: instance_metric(definition)
for definition in SERVER_METRICS_DEFINITIONS
}
def add_client(self, client):
self._clients.append(client)
def __getattr__(self, key: str):
return self._metrics[key]
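# Minimal usage sketch: __getattr__ resolves metric names to the underlying prometheus_client
# objects, so the disagg server endpoints can simply write, for example:
#   collector = DisaggPerfMetricsCollector(max_requests=1000)
#   collector.total_requests.inc()
#   collector.queue_latency_seconds.observe(0.004)  # value is illustrative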
async def add_per_request_metrics(
self,
ctx_server: str,

View File

@ -1,7 +1,7 @@
import asyncio
import heapq
from abc import ABC, abstractmethod
from typing import Dict, Iterable, List, Optional, Union
from typing import Awaitable, Callable, Dict, Iterable, List, Optional, Union
import aiohttp
from transformers import AutoTokenizer
@ -145,9 +145,15 @@ class KvCacheAwareServerState(ServerState):
class Router(ABC):
def __init__(self, server_role: ServerRole, servers: List[str],
metadata_server_cfg: Optional[MetadataServerConfig],
metadata_server: Optional[JsonDictionary]):
def __init__(
self,
server_role: ServerRole,
servers: List[str],
metadata_server_cfg: Optional[MetadataServerConfig],
metadata_server: Optional[JsonDictionary],
server_preparation_func: Optional[Callable[[str],
Awaitable[None]]] = None,
**kwargs):
self._servers = servers or []
self._metadata_server = metadata_server
self._server_role = server_role
@ -155,6 +161,7 @@ class Router(ABC):
self._monitor_task = None
self._session = None
self._health_check_timeout = metadata_server_cfg.health_check_timeout if metadata_server_cfg else None
self._server_preparation_func = server_preparation_func
@abstractmethod
def _on_servers_updated(self, old_servers, new_servers):
@ -169,16 +176,26 @@ class Router(ABC):
def servers(self) -> List[str]:
return self._servers
async def _prepare_server(self, server: str):
if self._server_preparation_func:
await self._server_preparation_func(server)
async def prepare_servers(self, servers: Optional[List[str]] = None):
for server in servers or self._servers:
await self._prepare_server(server)
async def add_server(self, server: str):
if server in self._servers:
logger.warning(f"Server {server} already exists")
return
await self._prepare_server(server)
async with self._lock:
old_servers = self._servers.copy()
self._servers = [*old_servers, server]
self._on_servers_updated(old_servers, self._servers)
logger.debug(
f"Added server {server}, current server list: {self._servers}")
f"Added server {server}, {self._server_role.name} current server list: {self._servers}"
)
async def remove_server(self, server: str):
if server not in self._servers:
@ -275,6 +292,7 @@ class Router(ABC):
# Log added servers
for server in final_servers:
if server not in old_servers:
await self._prepare_server(server)
logger.info(f"Server {server} is added")
else:
logger.debug(
@ -419,7 +437,7 @@ class RoundRobinRouter(Router):
metadata_server: JsonDictionary = None,
**kwargs):
super().__init__(server_role, servers, metadata_server_cfg,
metadata_server)
metadata_server, **kwargs)
self._server_idx = 0
def _on_servers_updated(self, old_servers, new_servers):
@ -463,7 +481,7 @@ class LoadBalancingRouter(Router):
use_tokens: bool = False,
**kwargs):
super().__init__(server_role, servers, metadata_server_cfg,
metadata_server)
metadata_server, **kwargs)
# Load map between servers and their number of tokens processed
self._server_state = {}
self._server_load_heap = []
@ -550,7 +568,7 @@ class KvCacheAwareRouter(Router):
tokens_per_block: int = 32,
**kwargs):
super().__init__(server_role, servers, metadata_server_cfg,
metadata_server)
metadata_server, **kwargs)
self._lock = asyncio.Lock()
self._use_tokens = use_tokens
@ -647,10 +665,13 @@ class KvCacheAwareRouter(Router):
self._server_state.pop(old_server, None)
def create_router(router_config: Optional[RouterConfig],
servers: Optional[List[str]],
metadata_server_cfg: Optional[MetadataServerConfig] = None,
metadata_server: Optional[JsonDictionary] = None) -> Router:
def create_router(
router_config: Optional[RouterConfig],
servers: Optional[List[str]],
metadata_server_cfg: Optional[MetadataServerConfig] = None,
metadata_server: Optional[JsonDictionary] = None,
server_preparation_func: Optional[Callable[[str], Awaitable[None]]] = None
) -> Router:
"""
Factory function to create different types of router instances.
@ -681,5 +702,8 @@ def create_router(router_config: Optional[RouterConfig],
extra_args = router_config.args if router_config else {}
return router_class(router_config.server_role if router_config else None,
servers, metadata_server_cfg, metadata_server,
servers,
metadata_server_cfg,
metadata_server,
server_preparation_func=server_preparation_func,
**extra_args)
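To make the new preparation hook concrete, here is a small self-contained sketch of the prepare-before-add pattern the Router now follows; TinyRouter and sync_clock are illustrative stand-ins, while in the disagg server the hook is self._sync_server_clock passed via create_router(..., server_preparation_func=...):

import asyncio
from typing import Awaitable, Callable, List, Optional

class TinyRouter:
    # Stand-in for Router: the real class also locks, deduplicates, and notifies subclasses.
    def __init__(self, prepare: Optional[Callable[[str], Awaitable[None]]] = None):
        self._servers: List[str] = []
        self._prepare = prepare

    async def add_server(self, server: str) -> None:
        if self._prepare:
            await self._prepare(server)  # e.g. sync the worker's steady clock first
        self._servers.append(server)

async def sync_clock(server: str) -> None:
    print(f"syncing steady clock of {server}")

asyncio.run(TinyRouter(sync_clock).add_server("localhost:8001"))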

View File

@ -154,7 +154,7 @@ def _run_worker(model_name, worker_config, role, port, work_dir, device=-1):
env = os.environ.copy()
if device != -1:
env["CUDA_VISIBLE_DEVICES"] = str(device)
log_path = os.path.join(work_dir, f"output_{role}.log")
log_path = os.path.join(work_dir, f"output_{role}_{port}.log")
log_file = open(log_path, "w+")
print(f"Running {role} on port {port}")
return ProcessWrapper(subprocess.Popen(cmd,

View File

@ -27,6 +27,8 @@ from defs.common import (revise_disagg_config_file_with_free_ports,
from defs.conftest import (get_sm_version, llm_models_root, skip_arm,
skip_no_hopper)
from defs.trt_test_alternative import check_call, check_output, popen
from test_common.perf_metrics_utils import (get_timing_metrics,
validate_timing_metrics)
from tensorrt_llm._utils import get_free_port, mpi_disabled
from tensorrt_llm.logger import logger
@ -41,112 +43,6 @@ def cleanup_output_files():
pass
def validate_timing_metrics(perf_metrics_item, request_context=""):
"""
Helper function to validate timing metrics relationships.
Args:
perf_metrics_item: A single performance metrics item from the /perf_metrics endpoint
request_context: String context for error messages (e.g., "request 1", "streaming")
"""
# Validate basic structure
required_keys = [
"ctx_server", "gen_server", "ctx_perf_metrics", "gen_perf_metrics",
"disagg_server_arrival_time", "disagg_server_first_token_time"
]
for key in required_keys:
assert key in perf_metrics_item, f"Missing key: {key} in {request_context}"
assert perf_metrics_item["ctx_perf_metrics"][
"ctx_request_id"] == perf_metrics_item["gen_perf_metrics"][
"ctx_request_id"]
# Extract timing metrics
ctx_metrics = perf_metrics_item["ctx_perf_metrics"]["perf_metrics"][
"timing_metrics"]
gen_metrics = perf_metrics_item["gen_perf_metrics"]["perf_metrics"][
"timing_metrics"]
disagg_arrival = perf_metrics_item["disagg_server_arrival_time"]
disagg_first_token = perf_metrics_item["disagg_server_first_token_time"]
# Validate disaggregated server timing metrics
assert disagg_arrival is not None, f"disagg_server_arrival_time is None in {request_context}"
assert disagg_first_token is not None, f"disagg_server_first_token_time is None in {request_context}"
assert isinstance(
disagg_arrival,
(int, float
)), f"disagg_server_arrival_time is not numeric in {request_context}"
assert isinstance(
disagg_first_token, (int, float)
), f"disagg_server_first_token_time is not numeric in {request_context}"
assert disagg_arrival > 0, f"disagg_server_arrival_time is not positive in {request_context}"
assert disagg_first_token > 0, f"disagg_server_first_token_time is not positive in {request_context}"
assert disagg_arrival <= disagg_first_token, f"disagg_server_arrival_time > disagg_server_first_token_time in {request_context}"
# Validate server-level timing metrics for context server
ctx_server_arrival = ctx_metrics.get("server_arrival_time")
ctx_server_first_token = ctx_metrics.get("server_first_token_time")
assert ctx_server_arrival is not None, f"ctx server_arrival_time is None in {request_context}"
assert ctx_server_first_token is not None, f"ctx server_first_token_time is None in {request_context}"
assert isinstance(
ctx_server_arrival,
(int,
float)), f"ctx server_arrival_time is not numeric in {request_context}"
assert isinstance(
ctx_server_first_token,
(int, float
)), f"ctx server_first_token_time is not numeric in {request_context}"
assert ctx_server_arrival <= ctx_server_first_token, f"ctx server_arrival_time > server_first_token_time in {request_context}"
assert ctx_metrics["last_token_time"] - ctx_server_first_token < 1e-3
# Validate server-level timing metrics for generation server
gen_server_arrival = gen_metrics.get("server_arrival_time")
gen_server_first_token = gen_metrics.get("server_first_token_time")
assert gen_server_arrival is not None, f"gen server_arrival_time is None in {request_context}"
assert gen_server_first_token is not None, f"gen server_first_token_time is None in {request_context}"
assert isinstance(
gen_server_arrival,
(int,
float)), f"gen server_arrival_time is not numeric in {request_context}"
assert isinstance(
gen_server_first_token,
(int, float
)), f"gen server_first_token_time is not numeric in {request_context}"
assert gen_server_arrival <= gen_server_first_token, f"gen server_arrival_time > server_first_token_time in {request_context}"
# Network Time Protocol can ensure ms-level accuracy in LAN
ntp_tolerance = 1e-3
# Validate timing relationships between different levels
# Disaggregated server should receive request before individual servers
assert disagg_arrival - ntp_tolerance <= ctx_server_arrival, f"disagg_arrival > ctx_server_arrival in {request_context}"
assert disagg_arrival - ntp_tolerance <= gen_server_arrival, f"disagg_arrival > gen_server_arrival in {request_context}"
# Context should complete before generation starts
assert ctx_server_first_token - ntp_tolerance <= gen_server_arrival, f"ctx_server_first_token > gen_server_arrival in {request_context}"
# Validate internal timing consistency
ctx_arrival_time = ctx_metrics["arrival_time"]
ctx_first_token_time = ctx_metrics["first_token_time"]
gen_arrival_time = gen_metrics["arrival_time"]
gen_first_token_time = gen_metrics["first_token_time"]
assert ctx_arrival_time <= ctx_first_token_time, f"ctx arrival_time > first_token_time in {request_context}"
assert gen_arrival_time <= gen_first_token_time, f"gen arrival_time > first_token_time in {request_context}"
# Test KV cache transfer timing (if present)
if "kv_cache_transfer_start" in gen_metrics and "kv_cache_transfer_end" in gen_metrics:
kv_start = gen_metrics["kv_cache_transfer_start"]
kv_end = gen_metrics["kv_cache_transfer_end"]
assert gen_metrics["kv_cache_size"] > 0
assert kv_start <= kv_end, f"kv_cache_transfer_start > kv_cache_transfer_end in {request_context}"
assert gen_arrival_time <= kv_start, f"gen_arrival_time > kv_cache_transfer_start in {request_context}"
assert kv_end <= gen_metrics[
"first_scheduled_time"], f"kv_cache_transfer_end > first_scheduled_time in {request_context}"
return True
def get_disagg_server_url_from_cfg(config_file: str) -> tuple[str, int]:
with open(config_file, 'r') as file:
config = yaml.safe_load(file)
@ -828,16 +724,7 @@ def test_disaggregated_perf_metrics(disaggregated_test_root, llm_venv,
os.symlink(src, dst, target_is_directory=True)
def extra_endpoints_test(server_url: str):
import json
import urllib.request
with urllib.request.urlopen(f"{server_url}/perf_metrics",
timeout=10) as resp:
assert resp.status == 200
perf_metrics = json.load(resp)
assert len(perf_metrics) > 0
item = perf_metrics[0]
item = get_timing_metrics(server_url)
# Use helper function to validate all timing metrics comprehensively
validate_timing_metrics(item, "perf_metrics test")

View File

@ -9,7 +9,6 @@ These tests verify that trtllm-serve handles error conditions gracefully:
"""
import asyncio
import socket
import time
from pathlib import Path
@ -19,11 +18,7 @@ import requests
from defs.conftest import llm_models_root
from defs.trt_test_alternative import popen, print_error, print_info
def _find_free_port() -> int:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", 0))
return s.getsockname()[1]
from tensorrt_llm._utils import get_free_port
class RemoteOpenAIServer:
@ -63,7 +58,7 @@ def server(model_path):
"""Start a test server for the module using popen like test_serve.py"""
host_bind = "0.0.0.0"
client_host = "localhost"
port = _find_free_port()
port = get_free_port()
cmd = [
"trtllm-serve",
"serve",

View File

@ -31,6 +31,7 @@ from _pytest.nodes import Item
from _pytest.python import Function
from defs.trt_test_alternative import (check_output, popen, print_error,
print_info)
from test_common.http_utils import wait_for_endpoint_ready
from tensorrt_llm._utils import get_free_port
@ -251,29 +252,6 @@ class PerfAggrScriptTestCmds(NamedTuple):
timeout: int
output_dir: str
def wait_for_endpoint_ready(self, url: str, timeout: int = 7200):
start = time.monotonic()
while True:
elapsed_time = time.monotonic() - start
if elapsed_time > timeout:
print_error(
f"Timeout waiting for endpoint {url} to be ready after {timeout} seconds"
)
break
try:
print_info(
f"Waiting for endpoint {url} to be ready, elapsed time: {elapsed_time}s"
)
time.sleep(1)
if requests.get(url).status_code == 200:
print_info(f"endpoint {url} is ready")
return
except Exception as err:
print_info(
f"endpoint {url} is not ready, with exception: {err}")
print_error(
f"Endpoint {url} did not become ready within {timeout} seconds")
def run_cmd(self, cmd_idx: int, venv) -> str:
output = ""
server_proc = None
@ -294,7 +272,7 @@ class PerfAggrScriptTestCmds(NamedTuple):
stderr=subprocess.STDOUT,
env=copy.deepcopy(os.environ),
)
self.wait_for_endpoint_ready(
wait_for_endpoint_ready(
f"http://{server_hostname}:{server_port}/health",
timeout=self.timeout)
client_cmd = add_host_port_to_cmd(self.client_cmds[cmd_idx],
@ -323,19 +301,6 @@ class PerfDisaggScriptTestCmds(NamedTuple):
client_cmd: List[str]
benchmark_cmd: List[str]
def wait_for_endpoint_ready(self, url: str, timeout: int = 600):
start = time.monotonic()
while time.monotonic() - start < timeout:
try:
time.sleep(1)
if requests.get(url).status_code == 200:
print(f"endpoint {url} is ready")
return
except Exception as err:
print(f"endpoint {url} is not ready, with exception: {err}")
print_error(
f"Endpoint {url} did not become ready within {timeout} seconds")
def run_cmd(self, cmd_idx: int, venv) -> str:
output = ""
try:
@ -360,7 +325,7 @@ class PerfDisaggScriptTestCmds(NamedTuple):
stderr=subprocess.STDOUT,
env=venv._new_env,
shell=True) as server_proc):
self.wait_for_endpoint_ready(
wait_for_endpoint_ready(
f"http://localhost:8000/health",
timeout=1800) # 30 minutes for large models
check_output(self.client_cmd, env=venv._new_env)

View File

@ -5,7 +5,7 @@ threadleak_exclude = asyncio_\d+
junit_family=legacy
addopts = --ignore-glob="*perf/test_perf.py" --ignore-glob="*perf/disagg/*" --ignore-glob="*test_list_validation.py" --ignore-glob="*llm-test-workspace*" --durations=0 -W ignore::DeprecationWarning
pythonpath =
../../../examples/auto_deploy
../../../examples/auto_deploy ../../
norecursedirs = ./triton/perf ./perf/disagg
markers =
skip_less_device: skip when less device detected than the declared

View File

@ -1769,6 +1769,21 @@ def test_trtllm_multimodal_benchmark_serving(llm_root, llm_venv):
])
@pytest.mark.skip_less_device(4)
@pytest.mark.skip_less_device_memory(40000)
@pytest.mark.parametrize("service_discovery", ["etcd", "http"])
def test_openai_disagg_multi_nodes_completion_service_discovery(
llm_root, llm_venv, service_discovery):
test_root = unittest_path() / "llmapi" / "apps"
llm_venv.run_cmd([
"-m",
"pytest",
str(test_root /
f"_test_disagg_serving_multi_nodes_service_discovery.py::test_completion[{service_discovery}]"
),
])
@pytest.mark.skip_less_device(4)
@pytest.mark.skip_less_device_memory(40000)
@pytest.mark.parametrize("gen_config",

View File

@ -11,3 +11,5 @@ test_e2e.py::test_multi_nodes_eval[Kimi-K2-Instruct-tp16-mmlu]
test_e2e.py::test_multi_nodes_eval[nemotron-nas/Llama-3_1-Nemotron-Ultra-253B-v1-tp16-mmlu]
test_e2e.py::test_openai_disagg_multi_nodes_completion[ctx_tp2pp1-gen_tp2pp1]
test_e2e.py::test_openai_disagg_multi_nodes_completion[ctx_tp1pp2-gen_tp1pp2]
test_e2e.py::test_openai_disagg_multi_nodes_completion_service_discovery[http]
test_e2e.py::test_openai_disagg_multi_nodes_completion_service_discovery[etcd]

View File

@ -42,6 +42,7 @@ l0_dgx_h100:
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-False-True]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-True-False]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-True-True]
- unittest/llmapi/apps/test_disagg_serving_perf_metrics.py
# ------------- AutoDeploy tests ---------------
- accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype[False-2]
# llmapi

View File

@ -4,12 +4,11 @@ import ast
import os
import subprocess
import sys
import time
from pathlib import Path
from typing import Dict, List, NamedTuple
import requests
import yaml
from test_common.http_utils import wait_for_endpoint_ready
def get_node_name() -> str:
@ -568,19 +567,6 @@ class PerfServerBenchmarkCmds(NamedTuple):
names: List[str]
working_dir: str
def wait_for_endpoint_ready(self, url: str, timeout: int = 5400):
start = time.monotonic()
while time.monotonic() - start < timeout:
try:
time.sleep(10)
if requests.get(url, timeout=5).status_code == 200:
print(f"endpoint {url} is ready")
return
except Exception as err:
print(f"endpoint {url} is not ready, with exception: {err}")
print_error(
f"Endpoint {url} did not become ready within {timeout} seconds")
def run_cmd(self,
cmd_idx: int,
node_name: str,
@ -601,8 +587,8 @@ class PerfServerBenchmarkCmds(NamedTuple):
stderr=subprocess.STDOUT)
# Wait for server to be ready
self.wait_for_endpoint_ready("http://localhost:8000/v1/models",
timeout=max_timeout)
wait_for_endpoint_ready("http://localhost:8000/v1/models",
timeout=max_timeout)
# Save node name, gpu info, server config, client config output to server file path
with open(client_file_path, 'w') as client_ctx:

View File

View File

@ -0,0 +1,29 @@
import time
import requests
def wait_for_endpoint_ready(url: str, timeout: int = 300):
start = time.monotonic()
while time.monotonic() - start < timeout:
try:
time.sleep(1)
if requests.get(url, timeout=5).status_code == 200:
print(f"endpoint {url} is ready")
return
except Exception as err:
print(f"endpoint {url} is not ready, with exception: {err}")
raise RuntimeError(f"Endpoint {url} did not become ready within {timeout} seconds")
def wait_for_endpoint_down(url: str, timeout: int = 300):
start = time.monotonic()
while time.monotonic() - start < timeout:
try:
status_code = requests.get(url, timeout=5).status_code
if status_code >= 100:
print(f"endpoint {url} is still up, returned status code {status_code}")
time.sleep(1)
except Exception as err:
print(f"endpoint {url} is down, with exception: {err}")
return
raise RuntimeError(f"Endpoint {url} did not become down within {timeout} seconds")

View File

@ -0,0 +1,188 @@
import re
import requests
def get_timing_metrics(server_url: str):
response = requests.get(f"{server_url}/perf_metrics", timeout=10)
assert response.status_code == 200
perf_metrics = response.json()
assert len(perf_metrics) > 0
return perf_metrics[0]
def validate_timing_metrics(perf_metrics_item, request_context="", time_tolerance_seconds=0.005):
"""Helper function to validate timing metrics relationships.
Args:
perf_metrics_item: A single performance metrics item from the /perf_metrics endpoint
request_context: String context for error messages (e.g., "request 1", "streaming")
time_tolerance_seconds: Allowed clock skew in seconds when comparing timestamps taken on different servers
"""
# Validate basic structure
required_keys = [
"ctx_server",
"gen_server",
"ctx_perf_metrics",
"gen_perf_metrics",
"disagg_server_arrival_time",
"disagg_server_first_token_time",
]
for key in required_keys:
assert key in perf_metrics_item, f"Missing key: {key} in {request_context}"
assert (
perf_metrics_item["ctx_perf_metrics"]["ctx_request_id"]
== perf_metrics_item["gen_perf_metrics"]["ctx_request_id"]
)
# Extract timing metrics
ctx_metrics = perf_metrics_item["ctx_perf_metrics"]["perf_metrics"]["timing_metrics"]
gen_metrics = perf_metrics_item["gen_perf_metrics"]["perf_metrics"]["timing_metrics"]
disagg_arrival = perf_metrics_item["disagg_server_arrival_time"]
disagg_first_token = perf_metrics_item["disagg_server_first_token_time"]
# Validate disaggregated server timing metrics
assert disagg_arrival is not None, f"disagg_server_arrival_time is None in {request_context}"
assert disagg_first_token is not None, (
f"disagg_server_first_token_time is None in {request_context}"
)
assert isinstance(disagg_arrival, (int, float)), (
f"disagg_server_arrival_time is not numeric in {request_context}"
)
assert isinstance(disagg_first_token, (int, float)), (
f"disagg_server_first_token_time is not numeric in {request_context}"
)
assert disagg_arrival > 0, f"disagg_server_arrival_time is not positive in {request_context}"
assert disagg_first_token > 0, (
f"disagg_server_first_token_time is not positive in {request_context}"
)
assert disagg_arrival <= disagg_first_token, (
f"disagg_server_arrival_time > disagg_server_first_token_time in {request_context}"
)
# Validate server-level timing metrics for context server
ctx_server_arrival = ctx_metrics.get("server_arrival_time")
ctx_server_first_token = ctx_metrics.get("server_first_token_time")
assert ctx_server_arrival is not None, f"ctx server_arrival_time is None in {request_context}"
assert ctx_server_first_token is not None, (
f"ctx server_first_token_time is None in {request_context}"
)
assert isinstance(ctx_server_arrival, (int, float)), (
f"ctx server_arrival_time is not numeric in {request_context}"
)
assert isinstance(ctx_server_first_token, (int, float)), (
f"ctx server_first_token_time is not numeric in {request_context}"
)
assert ctx_server_arrival <= ctx_server_first_token, (
f"ctx server_arrival_time > server_first_token_time in {request_context}"
)
assert ctx_metrics["last_token_time"] - ctx_server_first_token < 1e-3
# Validate server-level timing metrics for generation server
gen_server_arrival = gen_metrics.get("server_arrival_time")
gen_server_first_token = gen_metrics.get("server_first_token_time")
assert gen_server_arrival is not None, f"gen server_arrival_time is None in {request_context}"
assert gen_server_first_token is not None, (
f"gen server_first_token_time is None in {request_context}"
)
assert isinstance(gen_server_arrival, (int, float)), (
f"gen server_arrival_time is not numeric in {request_context}"
)
assert isinstance(gen_server_first_token, (int, float)), (
f"gen server_first_token_time is not numeric in {request_context}"
)
assert gen_server_arrival <= gen_server_first_token, (
f"gen server_arrival_time > server_first_token_time in {request_context}"
)
# Validate timing relationships between different levels
# Disaggregated server should receive request before individual servers
# Allow a tolerance of roughly one local-network ping when comparing times from the disagg
# and ctx/gen servers, accounting for the NTP error (about half the round-trip time).
assert disagg_arrival <= ctx_server_arrival + time_tolerance_seconds, (
f"disagg_arrival {disagg_arrival} > ctx_server_arrival {ctx_server_arrival} in {request_context}"
)
assert disagg_arrival <= gen_server_arrival + time_tolerance_seconds, (
f"disagg_arrival {disagg_arrival} > gen_server_arrival {gen_server_arrival} in {request_context}"
)
# Context should complete before generation starts
assert ctx_server_first_token <= gen_server_arrival + time_tolerance_seconds, (
f"ctx_server_first_token > gen_server_arrival in {request_context}"
)
# Validate internal timing consistency
ctx_arrival_time = ctx_metrics["arrival_time"]
ctx_first_token_time = ctx_metrics["first_token_time"]
gen_arrival_time = gen_metrics["arrival_time"]
gen_first_token_time = gen_metrics["first_token_time"]
assert ctx_arrival_time <= ctx_first_token_time, (
f"ctx arrival_time > first_token_time in {request_context}"
)
assert gen_arrival_time <= gen_first_token_time, (
f"gen arrival_time > first_token_time in {request_context}"
)
# Test KV cache transfer timing (if present)
if "kv_cache_transfer_start" in gen_metrics and "kv_cache_transfer_end" in gen_metrics:
kv_start = gen_metrics["kv_cache_transfer_start"]
kv_end = gen_metrics["kv_cache_transfer_end"]
assert gen_metrics["kv_cache_size"] > 0
assert kv_start <= kv_end, (
f"kv_cache_transfer_start > kv_cache_transfer_end in {request_context}"
)
assert gen_arrival_time <= kv_start, (
f"gen_arrival_time > kv_cache_transfer_start in {request_context}"
)
assert kv_end <= gen_metrics["first_scheduled_time"], (
f"kv_cache_transfer_end > first_scheduled_time in {request_context}"
)
return True
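# For reference, the ordering enforced above, where cross-server comparisons allow
# time_tolerance_seconds of skew:
#   disagg_arrival <= ctx_server_arrival <= ctx_server_first_token
#                  <= gen_server_arrival <= gen_server_first_token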
def get_prometheus_metrics(server_url: str):
response = requests.get(server_url + "/prometheus/metrics")
assert response.status_code == 200
# Parse Prometheus metrics lines into a dictionary of {metric_name: value}
metrics = {}
print(response.text)
for line in response.text.split("\n"):
if line.startswith("#") or not line.strip():
continue
parts = line.split()
if len(parts) < 2:
continue
metric = parts[0]
try:
value = float(parts[1])
except ValueError:
continue
if bucket_match := re.match(r'(.+)_bucket\{le="([^"]+)"\}', metric):
# Try to parse bucket boundaries out of metrics like ..._bucket{le="0.005"}
base_metric, le_value = bucket_match.groups()
if base_metric not in metrics:
metrics[base_metric] = {}
try:
metrics[base_metric][float(le_value)] = value
except ValueError:
continue
elif sum_match := re.match(r"(.+)_sum$", metric):
base_metric = sum_match.groups()[0]
if base_metric not in metrics:
metrics[base_metric] = {}
metrics[base_metric]["sum"] = value
elif count_match := re.match(r"(.+)_count$", metric):
base_metric = count_match.groups()[0]
if base_metric not in metrics:
metrics[base_metric] = {}
metrics[base_metric]["count"] = value
elif total_match := re.match(r"(.+)_total$", metric):
base_metric = total_match.groups()[0]
print(f"Total metric {metric}: {base_metric} = {value}")
metrics[base_metric] = value
else:
# ignore prometheus built-in metrics
pass
return metrics
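# Worked example (values are illustrative): the exposition lines
#   ctx_total_requests_total 10.0
#   queue_latency_seconds_bucket{le="0.005"} 8.0
#   queue_latency_seconds_sum 0.03
#   queue_latency_seconds_count 10.0
# are parsed into
#   {"ctx_total_requests": 10.0,
#    "queue_latency_seconds": {0.005: 8.0, "sum": 0.03, "count": 10.0}}
# (prometheus_client exposes a Counter named "ctx_total_requests" with a "_total" suffix,
# which is why the tests read metrics["ctx_total_requests"]).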

View File

@ -4,11 +4,15 @@ import time
import openai
import pytest
import requests
from test_common.http_utils import (wait_for_endpoint_down,
wait_for_endpoint_ready)
from test_common.perf_metrics_utils import (get_timing_metrics,
validate_timing_metrics)
from ..test_llm import get_model_path
from .openai_server import RemoteDisaggOpenAIServer, RemoteOpenAIServer
from .utils import expand_slurm_nodelist
from .utils import (expand_slurm_nodelist, wait_for_endpoint_down,
wait_for_endpoint_ready)
RANK = int(os.environ.get("SLURM_PROCID", 0))
NODE_RANK = int(os.environ.get("SLURM_NODEID", 0))
@ -19,7 +23,8 @@ pytestmark = pytest.mark.threadleak(enabled=False)
# This test assumes that there are >2 nodes: we run ctx/disagg-server/client on the first node
# and run gen on the second node.
# This is a multi-node test and will not be scheduled to the same node as other tests,
# so using fixed ports should be safe.
CTX_SERVER_PORT = 8001
GEN_SERVER_PORT = 8002
DISAGG_SERVER_PORT = 8000
@ -65,6 +70,7 @@ def env():
k: v
for k, v in os.environ.items()
if not ('PMI_' in k or 'OMPI_' in k or 'PMIX_' in k or 'SLURM_' in k)
and k not in ["UCX_TLS", "UCX_NET_DEVICES"] # avoid UCX failure on oci
}
@ -105,6 +111,8 @@ def worker(model_name: str, ctx_tp_pp_size: tuple, gen_tp_pp_size: tuple):
"enable_block_reuse": False,
},
"disable_overlap_scheduler": True,
"perf_metrics_max_requests": 1000,
"return_perf_metrics": True,
}
if is_ctx_node():
print(f"starting ctx_server for rank {RANK} node rank {NODE_RANK}")
@ -138,32 +146,6 @@ def worker(model_name: str, ctx_tp_pp_size: tuple, gen_tp_pp_size: tuple):
yield None
def wait_for_endpoint_ready(url: str, timeout: int = 300):
start = time.monotonic()
while time.monotonic() - start < timeout:
try:
time.sleep(1)
if requests.get(url).status_code == 200:
print(f"endpoint {url} is ready")
return
except Exception as err:
print(f"endpoint {url} is not ready, with exception: {err}")
def wait_for_endpoint_down(url: str, timeout: int = 300):
start = time.monotonic()
while time.monotonic() - start < timeout:
try:
if requests.get(url).status_code >= 100:
print(
f"endpoint {url} returned status code {requests.get(url).status_code}"
)
time.sleep(1)
except Exception as err:
print(f"endpoint {url} is down, with exception: {err}")
return
@pytest.fixture(scope="module")
def disagg_server(worker: RemoteOpenAIServer):
if is_disagg_node():
@ -210,6 +192,14 @@ def test_completion(client: openai.OpenAI,
assert completion.id is not None
message = completion.choices[0].text
assert message.startswith('2.')
perf_metrics = get_timing_metrics(disagg_server.url_root)
# allow 5ms tolerance when comparing the time points from disagg and ctx/gen servers
validate_timing_metrics(perf_metrics,
"multinode test_completion",
time_tolerance_seconds=0.005)
# sleep 10 seconds to ensure a successful wait_for_endpoint_ready on rank1
time.sleep(10)
disagg_server.terminate()
elif is_gen_node():

View File

@ -0,0 +1,220 @@
import os
import shutil
import subprocess
import tempfile
import uuid
import openai
import pytest
from test_common.perf_metrics_utils import get_timing_metrics, validate_timing_metrics
from tensorrt_llm._utils import get_free_port
from tensorrt_llm.llmapi.disagg_utils import ServerRole
from ..test_llm import get_model_path
from .openai_server import RemoteDisaggOpenAIServer, RemoteOpenAIServer
from .utils import expand_slurm_nodelist, wait_for_endpoint_down, wait_for_endpoint_ready
RANK = int(os.environ.get("SLURM_PROCID", 0))
NODE_RANK = int(os.environ.get("SLURM_NODEID", 0))
NODE_LIST = expand_slurm_nodelist(os.environ.get("SLURM_NODELIST", ""))
SLURM_NTASKS_PER_NODE = int(os.environ.get("SLURM_NTASKS_PER_NODE", 1))
# This is a multi-node QA test; use a fixed port instead of finding a free port
# so that all nodes can share the same disagg server config.
DISAGG_SERVER_PORT = 8000
# This test is expected to run with exactly 2 nodes
def is_ctx_node():
assert len(NODE_LIST) == 2
return NODE_RANK == 0
def is_gen_node():
assert len(NODE_LIST) == 2
return NODE_RANK == 1
def is_disagg_node():
return NODE_RANK == 0
# The test runs on multiple nodes, but only the first node's output is used for assertions
def is_pytest_node():
return NODE_RANK == 0
def env():
# Remove MPI-related environment variables to isolate the ctx/gen processes
# so that they do not join the same MPI communicator; otherwise rank and world_size may mismatch.
return {
k: v
for k, v in os.environ.items()
if not ("PMI_" in k or "OMPI_" in k or "PMIX_" in k or "SLURM_" in k)
and k not in ["UCX_TLS", "UCX_NET_DEVICES"]
}
@pytest.fixture
def model_name():
return "llama-3.1-model/Llama-3.1-8B-Instruct"
@pytest.fixture
def disagg_host():
return NODE_LIST[0]
@pytest.fixture(params=["etcd", "http"])
def service_discovery(request, disagg_host: str):
if request.param == "etcd":
work_dir = tempfile.mkdtemp()
data_dir = f"{work_dir}/disagg_test-etcd-{uuid.uuid4()}"
etcd = subprocess.Popen(["etcd", "--data-dir", data_dir])
yield etcd, f"etcd://{disagg_host}:2379"
try:
etcd.kill()
etcd.wait(timeout=10)
shutil.rmtree(data_dir)
except Exception:
pass
else:
yield None, f"http://{disagg_host}:{DISAGG_SERVER_PORT}"
@pytest.fixture
def disagg_cluster_config(service_discovery: tuple):
_, uri = service_discovery
return {
"cluster_uri": uri,
"cluster_name": "",
}
@pytest.fixture
def worker(model_name: str, disagg_cluster_config: dict):
extra_config = {
"disagg_cluster": disagg_cluster_config,
"cache_transceiver_config": {"backend": "DEFAULT"},
"kv_cache_config": {
"free_gpu_memory_fraction": 0.5,
"enable_block_reuse": False,
},
"disable_overlap_scheduler": True,
"return_perf_metrics": True,
"perf_metrics_max_requests": 1000,
}
# Start workers on 0.0.0.0:<free_port>; the workers should then be able to
# report their correct hostname:port to the disagg server.
port = get_free_port()
if is_ctx_node():
print(f"starting ctx_server for rank {RANK} node rank {NODE_RANK}")
model_path = get_model_path(model_name)
tp_size, pp_size = 1, 1
args = ["--tp_size", str(tp_size), "--pp_size", str(pp_size)]
with RemoteOpenAIServer(
model_path,
port=port,
cli_args=args,
host="0.0.0.0",
env=env(),
llmapi_launch=False,
rank=RANK % SLURM_NTASKS_PER_NODE,
extra_config=extra_config,
role=ServerRole.CONTEXT,
) as server:
yield server
elif is_gen_node():
print(f"starting gen_server for rank {RANK} node rank {NODE_RANK}")
model_path = get_model_path(model_name)
tp_size, pp_size = 1, 1
args = ["--tp_size", str(tp_size), "--pp_size", str(pp_size)]
with RemoteOpenAIServer(
model_path,
port=port,
cli_args=args,
host="0.0.0.0",
env=env(),
llmapi_launch=False,
rank=RANK % SLURM_NTASKS_PER_NODE,
extra_config=extra_config,
role=ServerRole.GENERATION,
) as server:
yield server
else:
yield None
# Unlike the non-service-discovery version, the disagg server doesn't have to
# wait for the ctx/gen servers to become ready.
@pytest.fixture
def disagg_server(disagg_cluster_config: dict):
if is_disagg_node():
disagg_config = {
"disagg_cluster": disagg_cluster_config,
"port": DISAGG_SERVER_PORT,
"hostname": "0.0.0.0",
"perf_metrics_max_requests": 1000,
}
print(f"starting disagg_server for rank {RANK} node rank {NODE_RANK}")
# ctx/gen servers are unnecessary for service discovery test
with RemoteDisaggOpenAIServer(
ctx_servers=[],
gen_servers=[],
port=DISAGG_SERVER_PORT,
disagg_config=disagg_config,
llmapi_launch=False,
env=env(),
wait_ready=False, # wait it to be ready in test body
) as server:
yield server
else:
print(f"skipping disagg_server for rank {RANK} node rank {NODE_RANK}")
yield None
@pytest.fixture
def client(disagg_server: RemoteDisaggOpenAIServer):
if is_pytest_node():
return disagg_server.get_client()
else:
print(f"skipping client for rank {RANK} node rank {NODE_RANK}")
return None
def test_completion(
disagg_server: RemoteDisaggOpenAIServer,
worker: RemoteOpenAIServer,
client: openai.OpenAI,
disagg_host: str,
model_name: str,
):
disagg_health_url = f"http://{disagg_host}:{DISAGG_SERVER_PORT}/health/"
wait_for_endpoint_ready(disagg_health_url)
if is_pytest_node():
print(f"running test_completion on rank {RANK} node rank {NODE_RANK}")
prompt = "What is the result of 1+1? Answer in one word: "
for _ in range(10):
completion = client.completions.create(
model=model_name,
prompt=prompt,
max_tokens=10,
temperature=0.0,
)
print(f"Output: {completion.choices[0].text}")
assert completion.id is not None
message = completion.choices[0].text
assert message.startswith("2.")
perf_metrics = get_timing_metrics(disagg_server.url_root)
validate_timing_metrics(perf_metrics, "multinode test_completion")
disagg_server.terminate()
elif is_gen_node():
# keep gen workers alive until the test ends
wait_for_endpoint_down(disagg_health_url)
assert True
else:
assert True

View File

@ -1,6 +1,5 @@
import os
import tempfile
from typing import List
import openai
import pytest
@ -8,42 +7,11 @@ import requests
import yaml
from ..test_llm import get_model_path
from .openai_server import RemoteOpenAIServer
from .openai_server import RemoteMMEncoderServer
pytestmark = pytest.mark.threadleak(enabled=False)
class RemoteMMEncoderServer(RemoteOpenAIServer):
"""Remote server for testing multimodal encoder endpoints."""
def __init__(self,
model: str,
cli_args: List[str] = None,
port: int = None) -> None:
# Reuse parent initialization but change the command
import subprocess
import sys
from tensorrt_llm.llmapi.mpi_session import find_free_port
self.host = "localhost"
self.port = port if port is not None else find_free_port()
self.rank = os.environ.get("SLURM_PROCID", 0)
args = ["--host", f"{self.host}", "--port", f"{self.port}"]
if cli_args:
args += cli_args
# Use mm_embedding_serve command instead of regular serve
launch_cmd = ["trtllm-serve", "mm_embedding_serve"] + [model] + args
self.proc = subprocess.Popen(launch_cmd,
stdout=sys.stdout,
stderr=sys.stderr)
self._wait_for_server(url=self.url_for("health"),
timeout=self.MAX_SERVER_START_WAIT_S)
@pytest.fixture(scope="module", ids=["Qwen2.5-VL-3B-Instruct"])
def model_name():
return "Qwen2.5-VL-3B-Instruct"

View File

@ -11,7 +11,8 @@ import openai
import requests
import yaml
from tensorrt_llm.llmapi.mpi_session import find_free_port
from tensorrt_llm._utils import get_free_port
from tensorrt_llm.llmapi.disagg_utils import ServerRole
class RemoteOpenAIServer:
@ -26,13 +27,21 @@ class RemoteOpenAIServer:
host: str = "localhost",
env: Optional[dict] = None,
rank: int = -1,
extra_config: Optional[dict] = None) -> None:
extra_config: Optional[dict] = None,
log_path: Optional[str] = None,
wait: bool = True,
role: Optional[ServerRole] = None) -> None:
self.host = host
self.port = port if port is not None else find_free_port()
self.port = port if port is not None else get_free_port()
self.rank = rank if rank != -1 else int(
os.environ.get("SLURM_PROCID", 0))
self.extra_config_file = None
self.log_path = log_path
self.log_file = None
self.role = role
args = ["--host", f"{self.host}", "--port", f"{self.port}"]
if self.role is not None:
args += ["--server_role", self.role.name]
if cli_args:
args += cli_args
if extra_config:
@ -50,10 +59,19 @@ class RemoteOpenAIServer:
env = os.environ.copy()
self.proc = subprocess.Popen(launch_cmd,
env=env,
stdout=sys.stdout,
stderr=sys.stderr)
self._wait_for_server(url=self.url_for("health"),
timeout=self.MAX_SERVER_START_WAIT_S)
stdout=self._get_output(),
stderr=self._get_output())
if wait:
self.wait_for_server(timeout=self.MAX_SERVER_START_WAIT_S)
def _get_output(self):
if self.log_file:
return self.log_file
elif self.log_path:
self.log_file = open(self.log_path, "w+")
return self.log_file
else:
return sys.stdout
def __enter__(self):
return self
@ -76,6 +94,12 @@ class RemoteOpenAIServer:
except Exception as e:
print(f"Error removing extra config file: {e}")
self.proc = None
if self.log_file:
self.log_file.close()
self.log_file = None
def wait_for_server(self, timeout: float):
self._wait_for_server(url=self.url_for("health"), timeout=timeout)
def _wait_for_server(self, *, url: str, timeout: float):
# run health check on the first rank only.
@ -128,21 +152,28 @@ class RemoteDisaggOpenAIServer(RemoteOpenAIServer):
gen_servers: List[str],
port: int = -1,
env: Optional[dict] = None,
llmapi_launch: bool = False) -> None:
llmapi_launch: bool = False,
disagg_config: Optional[dict] = None,
log_path: Optional[str] = None,
wait_ready: bool = True) -> None:
self.ctx_servers = ctx_servers
self.gen_servers = gen_servers
self.host = "localhost"
self.port = find_free_port() if port is None or port < 0 else port
self.host = "0.0.0.0"
self.port = get_free_port() if port is None or port < 0 else port
self.rank = 0
with tempfile.NamedTemporaryFile(mode="w+",
delete=False,
delete_on_close=False) as f:
f.write(self._get_extra_config())
f.flush()
self.extra_config_file = f.name
self.disagg_config = self._get_extra_config()
if disagg_config:
self.disagg_config.update(disagg_config)
self.log_path = log_path
self.log_file = None
self.extra_config_file = os.path.join(
tempfile.gettempdir(), f"disagg_config_{self.port}.yaml")
with open(self.extra_config_file, "w+") as f:
yaml.dump(self.disagg_config, f)
launch_cmd = [
"trtllm-serve", "disaggregated", "-c", self.extra_config_file
]
print(f"launch_cmd: {launch_cmd}, extra_config: {self.disagg_config}")
if llmapi_launch:
# start server with llmapi-launch on multi nodes
launch_cmd = ["trtllm-llmapi-launch"] + launch_cmd
@ -150,13 +181,14 @@ class RemoteDisaggOpenAIServer(RemoteOpenAIServer):
env = os.environ.copy()
self.proc = subprocess.Popen(launch_cmd,
env=env,
stdout=sys.stdout,
stderr=sys.stderr)
self._wait_for_server(url=self.url_for("health"),
timeout=self.MAX_SERVER_START_WAIT_S)
stdout=self._get_output(),
stderr=self._get_output())
if wait_ready:
self._wait_for_server(url=self.url_for("health"),
timeout=self.MAX_SERVER_START_WAIT_S)
def _get_extra_config(self):
return yaml.dump({
return {
"context_servers": {
"num_instances": len(self.ctx_servers),
"urls": self.ctx_servers
@ -167,4 +199,38 @@ class RemoteDisaggOpenAIServer(RemoteOpenAIServer):
},
"port": self.port,
"hostname": self.host,
})
"perf_metrics_max_requests": 1000,
}
class RemoteMMEncoderServer(RemoteOpenAIServer):
"""Remote server for testing multimodal encoder endpoints."""
def __init__(self,
model: str,
cli_args: List[str] = None,
port: int = None,
log_path: Optional[str] = None) -> None:
# Reuse parent initialization but change the command
import subprocess
from tensorrt_llm._utils import get_free_port
self.host = "localhost"
self.port = port if port is not None else get_free_port()
self.rank = os.environ.get("SLURM_PROCID", 0)
self.log_path = log_path
self.log_file = None
args = ["--host", f"{self.host}", "--port", f"{self.port}"]
if cli_args:
args += cli_args
# Use mm_embedding_serve command instead of regular serve
launch_cmd = ["trtllm-serve", "mm_embedding_serve"] + [model] + args
self.proc = subprocess.Popen(launch_cmd,
stdout=self._get_output(),
stderr=self._get_output())
self._wait_for_server(url=self.url_for("health"),
timeout=self.MAX_SERVER_START_WAIT_S)

View File

@ -0,0 +1,219 @@
import os
from typing import Tuple
import openai
import pytest
from test_common.http_utils import wait_for_endpoint_ready
from test_common.perf_metrics_utils import (
get_prometheus_metrics,
get_timing_metrics,
validate_timing_metrics,
)
from tensorrt_llm._utils import get_free_ports
from ..test_llm import get_model_path
from .openai_server import RemoteDisaggOpenAIServer, RemoteOpenAIServer
@pytest.fixture
def test_ports():
return get_free_ports(3)
@pytest.fixture
def disagg_port(test_ports: list[int]):
return test_ports[0]
@pytest.fixture
def ctx_port(test_ports: list[int]):
return test_ports[1]
@pytest.fixture
def gen_port(test_ports: list[int]):
return test_ports[2]
@pytest.fixture
def model_name():
return "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
@pytest.fixture
def disagg_cluster_config(disagg_port: int):
return {
"cluster_uri": f"http://localhost:{disagg_port}",
"cluster_name": "",
}
def worker_config(model_name: str, disagg_cluster_config: dict):
return {
"model": model_name,
"disagg_cluster": disagg_cluster_config,
"cache_transceiver_config": {
"backend": "DEFAULT",
},
"kv_cache_config": {
"free_gpu_memory_fraction": 0.2,
"enable_block_reuse": False,
},
"disable_overlap_scheduler": True,
"cuda_graph_config": None,
"return_perf_metrics": True,
"perf_metrics_max_requests": 1000,
}
@pytest.fixture
def workers(model_name: str, disagg_cluster_config: dict, ctx_port: int, gen_port: int):
model_path = get_model_path(model_name)
extra_config = worker_config(model_name, disagg_cluster_config)
def worker(server_role: str, port: int):
return RemoteOpenAIServer(
model_path,
port=port,
env=os.environ.copy(),
cli_args=["--server_role", server_role],
llmapi_launch=False,
extra_config=extra_config,
log_path=f"output_{server_role}.log",
wait=False,
)
with worker("context", ctx_port) as ctx_worker, worker("generation", gen_port) as gen_worker:
yield ctx_worker, gen_worker
@pytest.fixture
def disagg_server(disagg_cluster_config: dict, workers, disagg_port: int):
disagg_config = {
"port": disagg_port,
"disagg_cluster": disagg_cluster_config,
"perf_metrics_max_requests": 1000,
}
with RemoteDisaggOpenAIServer(
ctx_servers=[],
gen_servers=[],
port=disagg_config["port"],
llmapi_launch=False,
disagg_config=disagg_config,
) as server:
yield server
@pytest.fixture
def client(disagg_server: RemoteDisaggOpenAIServer):
return disagg_server.get_client()
@pytest.fixture
def async_client(disagg_server: RemoteDisaggOpenAIServer):
return disagg_server.get_async_client()
async def send_request(
client: openai.AsyncOpenAI, stream: bool, repeat: int, max_token: int, model_name: str
):
for _ in range(repeat):
prompt = "What is the result of 1+1? Answer in one word: "
completion = await client.completions.create(
model=model_name,
prompt=prompt,
max_tokens=max_token,
temperature=0.0,
stream=stream,
)
if stream:
output = []
async for chunk in completion:
output.append(chunk.choices[0].text)
assert len(output) > 0
message = "".join(output)
else:
assert completion.id is not None
message = completion.choices[0].text
assert message.startswith("2.")
def check_histogram(metrics_dict: dict, count: int, mean_range: tuple[float, float]):
assert metrics_dict["count"] == count
mean = metrics_dict["sum"] / metrics_dict["count"]
assert mean_range[0] < mean < mean_range[1]
@pytest.mark.asyncio
@pytest.mark.timeout(300)
async def test_completion_metrics(
async_client: openai.AsyncOpenAI,
workers: Tuple[RemoteOpenAIServer, RemoteOpenAIServer],
disagg_server: RemoteDisaggOpenAIServer,
model_name: str,
):
assert len(workers) == 2
for worker in workers:
worker.wait_for_server(timeout=120)
wait_for_endpoint_ready(disagg_server.url_root + "/health")
max_token = 10
total_requests = 10
await send_request(
client=async_client,
stream=True,
repeat=total_requests,
max_token=max_token,
model_name=model_name,
)
timing_metrics = get_timing_metrics(disagg_server.url_root)
validate_timing_metrics(timing_metrics, "test_completion_metrics")
metrics = get_prometheus_metrics(disagg_server.url_root)
print(metrics)
for role in ["ctx", "gen"]:
assert metrics[f"{role}_total_requests"] == total_requests
assert metrics[f"{role}_completed_requests"] == total_requests
assert metrics[f"{role}_error_requests"] == 0
assert f"{role}_retry_requests" in metrics
check_historgram(metrics["gen_first_token_latency_seconds"], total_requests, (0.0, 0.3))
check_historgram(metrics["gen_complete_latency_seconds"], total_requests, (0.0, 0.6))
assert metrics["total_requests"] == total_requests
assert metrics["stream_requests"] == total_requests
assert metrics["nonstream_requests"] == 0
assert metrics["total_responses"] == total_requests
assert metrics["validation_exceptions"] == 0
assert metrics["http_exceptions"] == 0
assert metrics["internal_errors"] == 0
check_historgram(metrics["queue_latency_seconds"], total_requests, (0.0, 0.03))
# test non streaming part
await send_request(
client=async_client,
stream=False,
repeat=total_requests,
max_token=max_token,
model_name=model_name,
)
metrics = get_prometheus_metrics(disagg_server.url_root)
for role in ["ctx", "gen"]:
assert metrics[f"{role}_total_requests"] == total_requests * 2
assert metrics[f"{role}_completed_requests"] == total_requests * 2
assert metrics[f"{role}_error_requests"] == 0
assert f"{role}_retry_requests" in metrics
assert metrics["total_requests"] == total_requests * 2
assert metrics["stream_requests"] == total_requests
assert metrics["nonstream_requests"] == total_requests
assert metrics["total_responses"] == total_requests * 2
assert metrics["validation_exceptions"] == 0
assert metrics["http_exceptions"] == 0
assert metrics["internal_errors"] == 0
check_historgram(metrics["gen_complete_latency_seconds"], total_requests * 2, (0.0, 0.6))
check_historgram(metrics["queue_latency_seconds"], total_requests * 2, (0.0, 0.03))

View File

@ -14,10 +14,12 @@
# limitations under the License.
import re
import time
from pathlib import Path
from typing import Any, Callable
import pytest
import requests
import yaml
from ..test_llm import get_model_path
@ -257,3 +259,29 @@ def expand_slurm_nodelist(nodelist_str):
expanded_nodes.append(group)
return expanded_nodes
def wait_for_endpoint_ready(url: str, timeout: int = 300, interval: int = 3):
start = time.monotonic()
while time.monotonic() - start < timeout:
try:
time.sleep(interval)
if requests.get(url).status_code == 200:
print(f"endpoint {url} is ready")
return
except Exception as err:
print(f"endpoint {url} is not ready, with exception: {err}")
def wait_for_endpoint_down(url: str, timeout: int = 300):
start = time.monotonic()
while time.monotonic() - start < timeout:
try:
status_code = requests.get(url).status_code
if status_code >= 100:
print(
f"endpoint {url} is still up, returned status code {status_code}"
)
time.sleep(1)
except Exception as err:
print(f"endpoint {url} is down, with exception: {err}")
return

View File

@ -9,6 +9,7 @@ pythonpath =
../../examples/auto_deploy
../../examples/models/core
../../examples
../
env =
D:AUTO_DEPLOY_LOG_LEVEL=INFO
markers =