Mirror of https://github.com/NVIDIA/TensorRT-LLM.git

[TRTLLM-9181][feat] improve disagg-server prometheus metrics; synchronize workers' clocks when workers are dynamic (#9726)
Signed-off-by: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com>
Parent: 609d1d0383
Commit: bd13957e70
@ -473,10 +473,20 @@ def dim_resolve_negative(dim, ndim):
    return tuple(pos)


def get_free_port():
    with socket.socket() as sock:
        sock.bind(("", 0))
        return sock.getsockname()[1]
def get_free_port() -> int:
    return get_free_ports(1)[0]


def get_free_ports(num=1) -> List[int]:
    sockets = [
        socket.socket(socket.AF_INET, socket.SOCK_STREAM) for _ in range(num)
    ]
    for s in sockets:
        s.bind(('', 0))
    ports = [s.getsockname()[1] for s in sockets]
    for s in sockets:
        s.close()
    return ports


# mpi4py only exports MPI_COMM_TYPE_SHARED, so we define OMPI_COMM_TYPE_HOST here
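A minimal usage sketch of the helpers above (hypothetical call site): the point of the rewrite is that get_free_ports binds every socket before closing any of them, so one call cannot hand back the same ephemeral port twice.

# Illustrative only -- assumes tensorrt_llm._utils exports both helpers as shown in this diff.
from tensorrt_llm._utils import get_free_port, get_free_ports

ctx_port, gen_port = get_free_ports(2)  # two distinct ephemeral ports
disagg_port = get_free_port()           # equivalent to get_free_ports(1)[0]
# The ports are free at return time but not reserved; start the servers promptly to avoid races.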
@ -2,6 +2,7 @@ import asyncio
import json
import os
import random
import socket
import time
from dataclasses import asdict, dataclass
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
@ -29,6 +30,18 @@ def get_worker_key(name: str, role: ServerRole, worker_id: str = "") -> str:
    return f"{get_worker_key_prefix(name)}/{worker_id}"


def get_host_from_uri(uri: str) -> str:
    return uri.split("://")[1].split(":")[0]


# Get the local ip address from a remote host,
# if remote host is not provided, use Google's public DNS server "8.8.8.8"
def get_local_ip(remote_host: str = "8.8.8.8") -> str:
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
        s.connect((remote_host, 80))
        return s.getsockname()[0]


class DisaggClusterManager:
    """
    The cluster manager is responsible for managing the workers in the cluster.
@ -238,18 +251,25 @@ class DisaggClusterWorker:
    It will send heartbeat to the cluster storage every heartbeat_interval_sec seconds.
    If the worker heartbeat fails, it will re-register itself.
    """
    LOCALHOST_IPS = ["localhost", "127.0.0.1", "0.0.0.0", "::1",
                     "::"]  # nosec B104

    def __init__(self, role: ServerRole, host: str, port: int,
                 config: DisaggClusterConfig, storage: ClusterStorage):
        self._role = role
        self._host = host
        self._port = port
        self._config = config
        self._cluster_storage = storage
        self._stop = False
        self._heartbeat_task = None
        self._last_heartbeat = 0
        self._worker_id = f"{role.name}-{host}:{port}-{int(time.time()*1000)}-{os.getpid()}-{random.randint(0, 1000):03}"
        register_host = host
        # if the host is localhost and the cluster uri is not localhost, use the hostname to register the worker
        disagg_host = get_host_from_uri(self._config.cluster_uri)
        if host in self.LOCALHOST_IPS and disagg_host not in self.LOCALHOST_IPS:
            register_host = get_local_ip(disagg_host)
        self._host = register_host
        self._worker_id = f"{role.name}-{register_host}:{port}-{int(time.time()*1000)}-{os.getpid()}-{random.randint(0, 1000):03}"

    def __del__(self):
        try:
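A standalone sketch of the registration rule above (addresses and the cluster URI are hypothetical): a wildcard or loopback bind address is only replaced when the cluster URI points somewhere routable, and the replacement is discovered by asking the kernel which local interface it would route toward the cluster host.

import socket

LOCALHOST_IPS = ["localhost", "127.0.0.1", "0.0.0.0", "::1", "::"]  # nosec B104

def local_ip_toward(remote_host: str) -> str:
    # same trick as get_local_ip above: a connected UDP socket reveals which local
    # interface the kernel would use; no packet needs to be sent or answered
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
        s.connect((remote_host, 80))
        return s.getsockname()[0]

cluster_uri = "etcd://10.0.0.5:2379"    # hypothetical service-discovery endpoint
bind_host, port = "0.0.0.0", 8001       # worker bound to all interfaces

disagg_host = cluster_uri.split("://")[1].split(":")[0]   # get_host_from_uri -> "10.0.0.5"
if bind_host in LOCALHOST_IPS and disagg_host not in LOCALHOST_IPS:
    register_host = local_ip_toward(disagg_host)
else:
    register_host = bind_host
print(f"worker registers itself as {register_host}:{port}")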
@ -183,6 +183,9 @@ class OpenAIHttpClient(OpenAIClient):
|
||||
yield response_dict
|
||||
# finish the request after the successful response
|
||||
await self._finish_request(request)
|
||||
self._metrics_collector.complete_latency_seconds.observe(
|
||||
get_steady_clock_now_in_seconds() - start_time
|
||||
)
|
||||
break # break and skip retries if the whole response is processed without exception
|
||||
except (aiohttp.ClientError, OSError) as e:
|
||||
if lines_yielded > 0:
|
||||
@ -227,25 +230,24 @@ class OpenAIHttpClient(OpenAIClient):
|
||||
i = 0
|
||||
async for line in http_response.content.iter_any():
|
||||
now_time = get_steady_clock_now_in_seconds()
|
||||
if i == 0:
|
||||
if hooks:
|
||||
hooks.on_first_token(server, request)
|
||||
self._metrics_collector.first_token_latency_seconds.observe(
|
||||
now_time - last_token_time
|
||||
)
|
||||
else:
|
||||
self._metrics_collector.per_token_latency_seconds.observe(
|
||||
now_time - last_token_time
|
||||
)
|
||||
i += 1
|
||||
if line:
|
||||
if i == 0:
|
||||
if hooks:
|
||||
hooks.on_first_token(server, request)
|
||||
self._metrics_collector.first_token_latency_seconds.observe(
|
||||
now_time - last_token_time
|
||||
)
|
||||
else:
|
||||
self._metrics_collector.per_token_latency_seconds.observe(
|
||||
now_time - last_token_time
|
||||
)
|
||||
i += 1
|
||||
yield line
|
||||
await asyncio.sleep(0)
|
||||
last_token_time = now_time
|
||||
|
||||
if hooks:
|
||||
hooks.on_resp_done(server, request, None)
|
||||
self._metrics_collector.completed_requests.inc()
|
||||
self._metrics_collector.complete_latency_seconds.observe(
|
||||
get_steady_clock_now_in_seconds() - start_time
|
||||
)
|
||||
@ -262,6 +264,7 @@ class OpenAIHttpClient(OpenAIClient):
|
||||
await self._finish_request(request)
|
||||
|
||||
async def _finish_request(self, request: UCompletionRequest) -> None:
|
||||
self._metrics_collector.completed_requests.inc()
|
||||
await self._router.finish_request(request)
|
||||
|
||||
async def collect_metrics(self) -> Dict[str, Any]:
|
||||
|
||||
@ -57,11 +57,12 @@ class RawRequestResponseHooks(ResponseHooks):
|
||||
self.raw_req = raw_req
|
||||
self.ctx_server = ""
|
||||
self.gen_server = ""
|
||||
self.request_arrival_time = raw_req.state.server_arrival_time
|
||||
self.server_first_token_time = 0
|
||||
self.perf_metrics_collector = perf_metrics_collector
|
||||
|
||||
def on_req_begin(self, request: UCompletionRequest):
|
||||
...
|
||||
self.perf_metrics_collector.queue_latency_seconds.observe(get_steady_clock_now_in_seconds() - self.request_arrival_time)
|
||||
|
||||
def on_ctx_resp(self, ctx_server: str, response: UCompletionResponse):
|
||||
self.ctx_server = ctx_server
|
||||
@ -93,8 +94,8 @@ class OpenAIDisaggServer:
|
||||
self._metrics_interval_secs = metrics_interval_secs
|
||||
|
||||
self._ctx_servers, self._gen_servers = get_ctx_gen_server_addrs(config.server_configs)
|
||||
self._ctx_router = create_router(config.ctx_router_config, self._ctx_servers, metadata_server_cfg, create_metadata_server(metadata_server_cfg))
|
||||
self._gen_router = create_router(config.gen_router_config, self._gen_servers, metadata_server_cfg, create_metadata_server(metadata_server_cfg))
|
||||
self._ctx_router = create_router(config.ctx_router_config, self._ctx_servers, metadata_server_cfg, create_metadata_server(metadata_server_cfg), self._sync_server_clock)
|
||||
self._gen_router = create_router(config.gen_router_config, self._gen_servers, metadata_server_cfg, create_metadata_server(metadata_server_cfg), self._sync_server_clock)
|
||||
self._metadata_server = create_metadata_server(metadata_server_cfg)
|
||||
self._perf_metrics_collector = DisaggPerfMetricsCollector(config.perf_metrics_max_requests)
|
||||
|
||||
@ -122,8 +123,10 @@ class OpenAIDisaggServer:
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app) -> None:
|
||||
# Prepare servers (sync server clock) when static ctx/gen server list is used
|
||||
await self._ctx_router.prepare_servers()
|
||||
await self._gen_router.prepare_servers()
|
||||
await self._service.setup()
|
||||
await self._set_steady_clock_offsets()
|
||||
yield
|
||||
await self._service.teardown()
|
||||
|
||||
@ -133,6 +136,7 @@ class OpenAIDisaggServer:
|
||||
|
||||
@self.app.exception_handler(RequestValidationError)
|
||||
async def validation_exception_handler(_, exc):
|
||||
self._perf_metrics_collector.validation_exceptions.inc()
|
||||
return JSONResponse(status_code=400, content={"error": str(exc)})
|
||||
|
||||
self.register_routes()
|
||||
@ -158,8 +162,14 @@ class OpenAIDisaggServer:
|
||||
def _wrap_entry_point(self, entry_point: Callable) -> Callable:
|
||||
async def wrapper(req: UCompletionRequest, raw_req: Request) -> Response:
|
||||
try:
|
||||
self._perf_metrics_collector.total_requests.inc()
|
||||
if req.stream:
|
||||
self._perf_metrics_collector.stream_requests.inc()
|
||||
else:
|
||||
self._perf_metrics_collector.nonstream_requests.inc()
|
||||
hooks = RawRequestResponseHooks(raw_req, self._perf_metrics_collector)
|
||||
response_or_generator = await entry_point(req, hooks)
|
||||
self._perf_metrics_collector.total_responses.inc()
|
||||
if req.stream:
|
||||
return StreamingResponse(content=response_or_generator, media_type="text/event-stream")
|
||||
else:
|
||||
@ -173,9 +183,11 @@ class OpenAIDisaggServer:
|
||||
logger.error("CppExecutorError: ", traceback.format_exc())
|
||||
signal.raise_signal(signal.SIGINT)
|
||||
elif isinstance(exception, HTTPException):
|
||||
self._perf_metrics_collector.http_exceptions.inc()
|
||||
logger.error(f"HTTPException {exception.status_code} {exception.detail}: ", traceback.format_exc())
|
||||
raise exception
|
||||
else:
|
||||
self._perf_metrics_collector.internal_errors.inc()
|
||||
logger.error("Internal server error: ", traceback.format_exc())
|
||||
raise HTTPException(status_code=500, detail=f"Internal server error {str(exception)}")
|
||||
|
||||
@ -199,13 +211,12 @@ class OpenAIDisaggServer:
                                timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
        await uvicorn.Server(config).serve(sockets=sockets)

    # TODO: rework this for service discovery, now it's only for static server list
    async def _set_steady_clock_offsets(self):
        STEADY_CLOCK_OFFSET_ENDPOINT = "/steady_clock_offset"
    async def _sync_server_clock(self, server: str):
        """ Sync the ctx/gen server's steady clock with the disagg-server's steady clock (in case NTP service is not running). """

        async def query_steady_clock_offset(session: aiohttp.ClientSession, server_url: str) -> tuple[Optional[float], Optional[float]]:
            try:
                originate_ts = get_steady_clock_now_in_seconds()
                async with session.get(server_url + STEADY_CLOCK_OFFSET_ENDPOINT) as response:
                async with session.get(server_url) as response:
                    destination_ts = get_steady_clock_now_in_seconds()
                    if response.status == 200:
                        response_content = await response.json()
@ -222,12 +233,11 @@ class OpenAIDisaggServer:

        async def set_steady_clock_offset(session: aiohttp.ClientSession, server_url: str, offset: float) -> None:
            payload = {"offset": offset}
            async with session.post(server_url + STEADY_CLOCK_OFFSET_ENDPOINT, json=payload) as response:
            async with session.post(server_url, json=payload) as response:
                if response.status != 200:
                    logger.warning(f"Cannot set disagg server steady clock offset for server {server_url}, the perf metrics timestamps could be mis-aligned")

        async def align_steady_clock_offset(session: aiohttp.ClientSession, server_url: str) -> None:
            server_url = f"http://{server_url}" if not server_url.startswith("http://") else server_url
            delay, offset = await query_steady_clock_offset(session, server_url)
            if delay is None or offset is None:
                logger.warning(f"Unable to measure steady clock offset for {server_url}; skipping adjustment")
@ -236,7 +246,13 @@ class OpenAIDisaggServer:
            # Negate the offset so that worker servers can adjust their steady clock by adding the new offset
            await set_steady_clock_offset(session, server_url, -offset)

        async with aiohttp.ClientSession(
                connector=aiohttp.TCPConnector(limit=0, limit_per_host=0, force_close=True),
                timeout=aiohttp.ClientTimeout(total=self._req_timeout_secs)) as session:
            await asyncio.gather(*[align_steady_clock_offset(session, server_url) for server_url in self._ctx_servers + self._gen_servers])
        server_scheme = "http://" if not server.startswith("http://") else ""
        server_url = f"{server_scheme}{server}/steady_clock_offset"

        try:
            async with aiohttp.ClientSession(
                    connector=aiohttp.TCPConnector(limit=0, limit_per_host=0, force_close=True),
                    timeout=aiohttp.ClientTimeout(total=self._req_timeout_secs)) as session:
                await align_steady_clock_offset(session, server_url)
        except (aiohttp.ClientError, OSError) as e:
            logger.warning(f"Unable to align steady clock offset for {server_url}: {e}; skipping adjustment")
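The response body and the exact arithmetic are outside this hunk; assuming the /steady_clock_offset endpoint returns the worker's receive/transmit timestamps, the usual NTP-style estimate (consistent with the originate/destination timestamps captured above and the "1/2 ping time" error note in the tests) would be:

# Sketch of standard NTP-style offset/delay math; receive_ts/transmit_ts are assumed
# response fields, only originate_ts/destination_ts appear verbatim in this diff.
def clock_offset_and_delay(originate_ts: float, receive_ts: float,
                           transmit_ts: float, destination_ts: float) -> tuple[float, float]:
    # offset: how far the worker's steady clock is ahead of the disagg-server's clock
    offset = ((receive_ts - originate_ts) + (transmit_ts - destination_ts)) / 2
    # delay: round-trip time minus the worker's processing time; half of it bounds the error
    delay = (destination_ts - originate_ts) - (transmit_ts - receive_ts)
    return offset, delay

Negating the offset before posting it, as the hunk above does, lets each worker correct its steady-clock readings by simply adding the value it receives.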
@ -15,7 +15,7 @@
import asyncio
from collections import defaultdict, deque
from dataclasses import dataclass
from typing import Any, Dict, List, Literal, Optional, Union
from typing import Any, Dict, List, Literal, Optional

from tensorrt_llm.llmapi.disagg_utils import ServerRole

@ -64,7 +64,7 @@ class MetricsDefinition:
    buckets: Optional[List[float]] = None


METRICS_DEFINITIONS = [
CLIENT_METRICS_DEFINITIONS = [
    MetricsDefinition("total_requests", "Total number of requests", "counter"),
    MetricsDefinition("error_requests", "Total number of error requests", "counter"),
    MetricsDefinition("retry_requests", "Total number of retry requests", "counter"),
@ -96,23 +96,29 @@ ROLE_TO_CLIENT_TYPE = {
}


def instance_metric(definition: MetricsDefinition, role: Optional[ServerRole] = None):
    # import lazily to avoid breaking `set_prometheus_multiproc_dir`
    from prometheus_client import Counter, Histogram

    name = (
        f"{ROLE_TO_CLIENT_TYPE[role]}_{definition.name}"
        if role in ROLE_TO_CLIENT_TYPE
        else definition.name
    )
    if definition.type == "counter":
        return Counter(name, definition.description)
    elif definition.type == "histogram":
        return Histogram(name, definition.description, buckets=definition.buckets)
    else:
        raise ValueError(f"Invalid metric type: {definition.type}")


class ClientMetricsCollector:

    def __init__(self, role: ServerRole):
        self._role = role
        # import lazily to avoid breaking `set_prometheus_multiproc_dir`
        from prometheus_client import Counter, Histogram

        def instance_metric(definition: MetricsDefinition) -> Union[Counter | Histogram]:
            name = f"{ROLE_TO_CLIENT_TYPE[role]}_{definition.name}"
            if definition.type == "counter":
                return Counter(name, definition.description)
            elif definition.type == "histogram":
                return Histogram(name, definition.description, buckets=definition.buckets)
            else:
                raise ValueError(f"Invalid metric type: {definition.type}")

        self._metrics = {
            definition.name: instance_metric(definition) for definition in METRICS_DEFINITIONS
            definition.name: instance_metric(definition, role)
            for definition in CLIENT_METRICS_DEFINITIONS
        }

    def __getattr__(
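A brief sketch of how the hoisted helper is used after this refactor. The concrete prefixes in ROLE_TO_CLIENT_TYPE are not shown in this hunk, and re-registering an existing metric name raises in prometheus_client, so treat this as illustration rather than a runnable test:

# Hypothetical: assumes ROLE_TO_CLIENT_TYPE maps ServerRole.CONTEXT to some prefix string.
from tensorrt_llm.llmapi.disagg_utils import ServerRole

ctx_client_metrics = ClientMetricsCollector(ServerRole.CONTEXT)
ctx_client_metrics.total_requests.inc()      # exported name carries the context-role prefix
ctx_client_metrics.retry_requests.inc()

demo_hist = instance_metric(
    MetricsDefinition("demo_latency_seconds", "Example histogram", "histogram",
                      [0.001, 0.01, 0.1, 1.0]))   # placeholder buckets
demo_hist.observe(0.02)                      # role=None keeps the bare metric name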
@ -121,6 +127,23 @@ class ClientMetricsCollector:
        return self._metrics[key]


SERVER_METRICS_DEFINITIONS = [
    MetricsDefinition("total_requests", "Total number of requests", "counter"),
    MetricsDefinition("stream_requests", "Total number of stream requests", "counter"),
    MetricsDefinition("nonstream_requests", "Total number of non-stream requests", "counter"),
    MetricsDefinition("validation_exceptions", "Total number of validation exceptions", "counter"),
    MetricsDefinition("http_exceptions", "Total number of HTTP exceptions", "counter"),
    MetricsDefinition("internal_errors", "Total number of internal errors", "counter"),
    MetricsDefinition("total_responses", "Total number of responses", "counter"),
    MetricsDefinition(
        "queue_latency_seconds",
        "Histogram of latency from request arrival to being processed in seconds",
        "histogram",
        SHORT_TIME_BUCKETS,
    ),
]


class DisaggPerfMetricsCollector:

    def __init__(self, max_requests: int):
        self._max_requests = max_requests
@ -128,10 +151,17 @@ class DisaggPerfMetricsCollector:
        self._server_metrics = defaultdict(dict)
        self._lock = asyncio.Lock()
        self._clients = []
        self._metrics = {
            definition.name: instance_metric(definition)
            for definition in SERVER_METRICS_DEFINITIONS
        }

    def add_client(self, client):
        self._clients.append(client)

    def __getattr__(self, key: str):
        return self._metrics[key]

    async def add_per_request_metrics(
            self,
            ctx_server: str,
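The __getattr__ indirection is what lets the request handlers elsewhere in this diff bump counters directly by name. A minimal sketch of the access pattern (construction mirrors the disagg server's usage; the max_requests value is arbitrary):

# Metric names resolve through __getattr__ into the underlying prometheus_client objects.
collector = DisaggPerfMetricsCollector(max_requests=1000)
collector.total_requests.inc()                   # Counter from SERVER_METRICS_DEFINITIONS
collector.stream_requests.inc()
collector.queue_latency_seconds.observe(0.015)   # Histogram bucketed by SHORT_TIME_BUCKETS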
@ -1,7 +1,7 @@
|
||||
import asyncio
|
||||
import heapq
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Iterable, List, Optional, Union
|
||||
from typing import Awaitable, Callable, Dict, Iterable, List, Optional, Union
|
||||
|
||||
import aiohttp
|
||||
from transformers import AutoTokenizer
|
||||
@ -145,9 +145,15 @@ class KvCacheAwareServerState(ServerState):
|
||||
|
||||
class Router(ABC):
|
||||
|
||||
def __init__(self, server_role: ServerRole, servers: List[str],
|
||||
metadata_server_cfg: Optional[MetadataServerConfig],
|
||||
metadata_server: Optional[JsonDictionary]):
|
||||
def __init__(
|
||||
self,
|
||||
server_role: ServerRole,
|
||||
servers: List[str],
|
||||
metadata_server_cfg: Optional[MetadataServerConfig],
|
||||
metadata_server: Optional[JsonDictionary],
|
||||
server_preparation_func: Optional[Callable[[str],
|
||||
Awaitable[None]]] = None,
|
||||
**kwargs):
|
||||
self._servers = servers or []
|
||||
self._metadata_server = metadata_server
|
||||
self._server_role = server_role
|
||||
@ -155,6 +161,7 @@ class Router(ABC):
|
||||
self._monitor_task = None
|
||||
self._session = None
|
||||
self._health_check_timeout = metadata_server_cfg.health_check_timeout if metadata_server_cfg else None
|
||||
self._server_preparation_func = server_preparation_func
|
||||
|
||||
@abstractmethod
|
||||
def _on_servers_updated(self, old_servers, new_servers):
|
||||
@ -169,16 +176,26 @@ class Router(ABC):
    def servers(self) -> List[str]:
        return self._servers

    async def _prepare_server(self, server: str):
        if self._server_preparation_func:
            await self._server_preparation_func(server)

    async def prepare_servers(self, servers: Optional[List[str]] = None):
        for server in servers or self._servers:
            await self._prepare_server(server)

    async def add_server(self, server: str):
        if server in self._servers:
            logger.warning(f"Server {server} already exists")
            return
        await self._prepare_server(server)
        async with self._lock:
            old_servers = self._servers.copy()
            self._servers = [*old_servers, server]
            self._on_servers_updated(old_servers, self._servers)
            logger.debug(
                f"Added server {server}, current server list: {self._servers}")
                f"Added server {server}, {self._server_role.name} current server list: {self._servers}"
            )

    async def remove_server(self, server: str):
        if server not in self._servers:
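Taken together with the disagg-server changes above, the preparation hook means every routed server — whether it comes from the static config or is discovered later — gets its steady clock synced before it receives traffic. A hedged sketch of the wiring (router_config and metadata_server_cfg are placeholders for objects the caller already has; the server address is made up):

async def build_gen_router(router_config, metadata_server_cfg, sync_server_clock):
    # sync_server_clock plays the role of OpenAIDisaggServer._sync_server_clock:
    # an async callable that POSTs the negated offset to http://<server>/steady_clock_offset.
    router = create_router(router_config,
                           ["gen-node-0:8002"],        # hypothetical static entry
                           metadata_server_cfg,
                           None,                        # no metadata server in this sketch
                           server_preparation_func=sync_server_clock)
    # Static list: prepared once at startup (the app lifespan calls prepare_servers()).
    # Dynamic list: add_server() and the metadata refresh call _prepare_server() first.
    await router.prepare_servers()
    return router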
@ -275,6 +292,7 @@ class Router(ABC):
|
||||
# Log added servers
|
||||
for server in final_servers:
|
||||
if server not in old_servers:
|
||||
await self._prepare_server(server)
|
||||
logger.info(f"Server {server} is added")
|
||||
else:
|
||||
logger.debug(
|
||||
@ -419,7 +437,7 @@ class RoundRobinRouter(Router):
|
||||
metadata_server: JsonDictionary = None,
|
||||
**kwargs):
|
||||
super().__init__(server_role, servers, metadata_server_cfg,
|
||||
metadata_server)
|
||||
metadata_server, **kwargs)
|
||||
self._server_idx = 0
|
||||
|
||||
def _on_servers_updated(self, old_servers, new_servers):
|
||||
@ -463,7 +481,7 @@ class LoadBalancingRouter(Router):
|
||||
use_tokens: bool = False,
|
||||
**kwargs):
|
||||
super().__init__(server_role, servers, metadata_server_cfg,
|
||||
metadata_server)
|
||||
metadata_server, **kwargs)
|
||||
# Load map between servers and their number of tokens processed
|
||||
self._server_state = {}
|
||||
self._server_load_heap = []
|
||||
@ -550,7 +568,7 @@ class KvCacheAwareRouter(Router):
|
||||
tokens_per_block: int = 32,
|
||||
**kwargs):
|
||||
super().__init__(server_role, servers, metadata_server_cfg,
|
||||
metadata_server)
|
||||
metadata_server, **kwargs)
|
||||
self._lock = asyncio.Lock()
|
||||
self._use_tokens = use_tokens
|
||||
|
||||
@ -647,10 +665,13 @@ class KvCacheAwareRouter(Router):
|
||||
self._server_state.pop(old_server, None)
|
||||
|
||||
|
||||
def create_router(router_config: Optional[RouterConfig],
|
||||
servers: Optional[List[str]],
|
||||
metadata_server_cfg: Optional[MetadataServerConfig] = None,
|
||||
metadata_server: Optional[JsonDictionary] = None) -> Router:
|
||||
def create_router(
|
||||
router_config: Optional[RouterConfig],
|
||||
servers: Optional[List[str]],
|
||||
metadata_server_cfg: Optional[MetadataServerConfig] = None,
|
||||
metadata_server: Optional[JsonDictionary] = None,
|
||||
server_preparation_func: Optional[Callable[[str], Awaitable[None]]] = None
|
||||
) -> Router:
|
||||
"""
|
||||
Factory function to create different types of router instances.
|
||||
|
||||
@ -681,5 +702,8 @@ def create_router(router_config: Optional[RouterConfig],
|
||||
extra_args = router_config.args if router_config else {}
|
||||
|
||||
return router_class(router_config.server_role if router_config else None,
|
||||
servers, metadata_server_cfg, metadata_server,
|
||||
servers,
|
||||
metadata_server_cfg,
|
||||
metadata_server,
|
||||
server_preparation_func=server_preparation_func,
|
||||
**extra_args)
|
||||
|
||||
@ -154,7 +154,7 @@ def _run_worker(model_name, worker_config, role, port, work_dir, device=-1):
|
||||
env = os.environ.copy()
|
||||
if device != -1:
|
||||
env["CUDA_VISIBLE_DEVICES"] = str(device)
|
||||
log_path = os.path.join(work_dir, f"output_{role}.log")
|
||||
log_path = os.path.join(work_dir, f"output_{role}_{port}.log")
|
||||
log_file = open(log_path, "w+")
|
||||
print(f"Running {role} on port {port}")
|
||||
return ProcessWrapper(subprocess.Popen(cmd,
|
||||
|
||||
@ -27,6 +27,8 @@ from defs.common import (revise_disagg_config_file_with_free_ports,
|
||||
from defs.conftest import (get_sm_version, llm_models_root, skip_arm,
|
||||
skip_no_hopper)
|
||||
from defs.trt_test_alternative import check_call, check_output, popen
|
||||
from test_common.perf_metrics_utils import (get_timing_metrics,
|
||||
validate_timing_metrics)
|
||||
|
||||
from tensorrt_llm._utils import get_free_port, mpi_disabled
|
||||
from tensorrt_llm.logger import logger
|
||||
@ -41,112 +43,6 @@ def cleanup_output_files():
|
||||
pass
|
||||
|
||||
|
||||
def validate_timing_metrics(perf_metrics_item, request_context=""):
|
||||
"""
|
||||
Helper function to validate timing metrics relationships.
|
||||
|
||||
Args:
|
||||
perf_metrics_item: A single performance metrics item from the /perf_metrics endpoint
|
||||
request_context: String context for error messages (e.g., "request 1", "streaming")
|
||||
"""
|
||||
# Validate basic structure
|
||||
required_keys = [
|
||||
"ctx_server", "gen_server", "ctx_perf_metrics", "gen_perf_metrics",
|
||||
"disagg_server_arrival_time", "disagg_server_first_token_time"
|
||||
]
|
||||
for key in required_keys:
|
||||
assert key in perf_metrics_item, f"Missing key: {key} in {request_context}"
|
||||
|
||||
assert perf_metrics_item["ctx_perf_metrics"][
|
||||
"ctx_request_id"] == perf_metrics_item["gen_perf_metrics"][
|
||||
"ctx_request_id"]
|
||||
|
||||
# Extract timing metrics
|
||||
ctx_metrics = perf_metrics_item["ctx_perf_metrics"]["perf_metrics"][
|
||||
"timing_metrics"]
|
||||
gen_metrics = perf_metrics_item["gen_perf_metrics"]["perf_metrics"][
|
||||
"timing_metrics"]
|
||||
disagg_arrival = perf_metrics_item["disagg_server_arrival_time"]
|
||||
disagg_first_token = perf_metrics_item["disagg_server_first_token_time"]
|
||||
|
||||
# Validate disaggregated server timing metrics
|
||||
assert disagg_arrival is not None, f"disagg_server_arrival_time is None in {request_context}"
|
||||
assert disagg_first_token is not None, f"disagg_server_first_token_time is None in {request_context}"
|
||||
assert isinstance(
|
||||
disagg_arrival,
|
||||
(int, float
|
||||
)), f"disagg_server_arrival_time is not numeric in {request_context}"
|
||||
assert isinstance(
|
||||
disagg_first_token, (int, float)
|
||||
), f"disagg_server_first_token_time is not numeric in {request_context}"
|
||||
assert disagg_arrival > 0, f"disagg_server_arrival_time is not positive in {request_context}"
|
||||
assert disagg_first_token > 0, f"disagg_server_first_token_time is not positive in {request_context}"
|
||||
assert disagg_arrival <= disagg_first_token, f"disagg_server_arrival_time > disagg_server_first_token_time in {request_context}"
|
||||
|
||||
# Validate server-level timing metrics for context server
|
||||
ctx_server_arrival = ctx_metrics.get("server_arrival_time")
|
||||
ctx_server_first_token = ctx_metrics.get("server_first_token_time")
|
||||
assert ctx_server_arrival is not None, f"ctx server_arrival_time is None in {request_context}"
|
||||
assert ctx_server_first_token is not None, f"ctx server_first_token_time is None in {request_context}"
|
||||
assert isinstance(
|
||||
ctx_server_arrival,
|
||||
(int,
|
||||
float)), f"ctx server_arrival_time is not numeric in {request_context}"
|
||||
assert isinstance(
|
||||
ctx_server_first_token,
|
||||
(int, float
|
||||
)), f"ctx server_first_token_time is not numeric in {request_context}"
|
||||
assert ctx_server_arrival <= ctx_server_first_token, f"ctx server_arrival_time > server_first_token_time in {request_context}"
|
||||
assert ctx_metrics["last_token_time"] - ctx_server_first_token < 1e-3
|
||||
|
||||
# Validate server-level timing metrics for generation server
|
||||
gen_server_arrival = gen_metrics.get("server_arrival_time")
|
||||
gen_server_first_token = gen_metrics.get("server_first_token_time")
|
||||
assert gen_server_arrival is not None, f"gen server_arrival_time is None in {request_context}"
|
||||
assert gen_server_first_token is not None, f"gen server_first_token_time is None in {request_context}"
|
||||
assert isinstance(
|
||||
gen_server_arrival,
|
||||
(int,
|
||||
float)), f"gen server_arrival_time is not numeric in {request_context}"
|
||||
assert isinstance(
|
||||
gen_server_first_token,
|
||||
(int, float
|
||||
)), f"gen server_first_token_time is not numeric in {request_context}"
|
||||
assert gen_server_arrival <= gen_server_first_token, f"gen server_arrival_time > server_first_token_time in {request_context}"
|
||||
|
||||
# Network Time Protocol can ensure ms-level accuracy in LAN
|
||||
ntp_tolerance = 1e-3
|
||||
|
||||
# Validate timing relationships between different levels
|
||||
# Disaggregated server should receive request before individual servers
|
||||
assert disagg_arrival - ntp_tolerance <= ctx_server_arrival, f"disagg_arrival > ctx_server_arrival in {request_context}"
|
||||
assert disagg_arrival - ntp_tolerance <= gen_server_arrival, f"disagg_arrival > gen_server_arrival in {request_context}"
|
||||
|
||||
# Context should complete before generation starts
|
||||
assert ctx_server_first_token - ntp_tolerance <= gen_server_arrival, f"ctx_server_first_token > gen_server_arrival in {request_context}"
|
||||
|
||||
# Validate internal timing consistency
|
||||
ctx_arrival_time = ctx_metrics["arrival_time"]
|
||||
ctx_first_token_time = ctx_metrics["first_token_time"]
|
||||
gen_arrival_time = gen_metrics["arrival_time"]
|
||||
gen_first_token_time = gen_metrics["first_token_time"]
|
||||
|
||||
assert ctx_arrival_time <= ctx_first_token_time, f"ctx arrival_time > first_token_time in {request_context}"
|
||||
assert gen_arrival_time <= gen_first_token_time, f"gen arrival_time > first_token_time in {request_context}"
|
||||
|
||||
# Test KV cache transfer timing (if present)
|
||||
if "kv_cache_transfer_start" in gen_metrics and "kv_cache_transfer_end" in gen_metrics:
|
||||
kv_start = gen_metrics["kv_cache_transfer_start"]
|
||||
kv_end = gen_metrics["kv_cache_transfer_end"]
|
||||
assert gen_metrics["kv_cache_size"] > 0
|
||||
assert kv_start <= kv_end, f"kv_cache_transfer_start > kv_cache_transfer_end in {request_context}"
|
||||
assert gen_arrival_time <= kv_start, f"gen_arrival_time > kv_cache_transfer_start in {request_context}"
|
||||
assert kv_end <= gen_metrics[
|
||||
"first_scheduled_time"], f"kv_cache_transfer_end > first_scheduled_time in {request_context}"
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def get_disagg_server_url_from_cfg(config_file: str) -> tuple[str, int]:
|
||||
with open(config_file, 'r') as file:
|
||||
config = yaml.safe_load(file)
|
||||
@ -828,16 +724,7 @@ def test_disaggregated_perf_metrics(disaggregated_test_root, llm_venv,
|
||||
os.symlink(src, dst, target_is_directory=True)
|
||||
|
||||
def extra_endpoints_test(server_url: str):
|
||||
import json
|
||||
import urllib.request
|
||||
|
||||
with urllib.request.urlopen(f"{server_url}/perf_metrics",
|
||||
timeout=10) as resp:
|
||||
assert resp.status == 200
|
||||
perf_metrics = json.load(resp)
|
||||
assert len(perf_metrics) > 0
|
||||
item = perf_metrics[0]
|
||||
|
||||
item = get_timing_metrics(server_url)
|
||||
# Use helper function to validate all timing metrics comprehensively
|
||||
validate_timing_metrics(item, "perf_metrics test")
|
||||
|
||||
|
||||
@ -9,7 +9,6 @@ These tests verify that trtllm-serve handles error conditions gracefully:
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import socket
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
@ -19,11 +18,7 @@ import requests
|
||||
from defs.conftest import llm_models_root
|
||||
from defs.trt_test_alternative import popen, print_error, print_info
|
||||
|
||||
|
||||
def _find_free_port() -> int:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.bind(("", 0))
|
||||
return s.getsockname()[1]
|
||||
from tensorrt_llm._utils import get_free_port
|
||||
|
||||
|
||||
class RemoteOpenAIServer:
|
||||
@ -63,7 +58,7 @@ def server(model_path):
|
||||
"""Start a test server for the module using popen like test_serve.py"""
|
||||
host_bind = "0.0.0.0"
|
||||
client_host = "localhost"
|
||||
port = _find_free_port()
|
||||
port = get_free_port()
|
||||
cmd = [
|
||||
"trtllm-serve",
|
||||
"serve",
|
||||
|
||||
@ -31,6 +31,7 @@ from _pytest.nodes import Item
|
||||
from _pytest.python import Function
|
||||
from defs.trt_test_alternative import (check_output, popen, print_error,
|
||||
print_info)
|
||||
from test_common.http_utils import wait_for_endpoint_ready
|
||||
|
||||
from tensorrt_llm._utils import get_free_port
|
||||
|
||||
@ -251,29 +252,6 @@ class PerfAggrScriptTestCmds(NamedTuple):
|
||||
timeout: int
|
||||
output_dir: str
|
||||
|
||||
def wait_for_endpoint_ready(self, url: str, timeout: int = 7200):
|
||||
start = time.monotonic()
|
||||
while True:
|
||||
elapsed_time = time.monotonic() - start
|
||||
if elapsed_time > timeout:
|
||||
print_error(
|
||||
f"Timeout waiting for endpoint {url} to be ready after {timeout} seconds"
|
||||
)
|
||||
break
|
||||
try:
|
||||
print_info(
|
||||
f"Waiting for endpoint {url} to be ready, elapsed time: {elapsed_time}s"
|
||||
)
|
||||
time.sleep(1)
|
||||
if requests.get(url).status_code == 200:
|
||||
print_info(f"endpoint {url} is ready")
|
||||
return
|
||||
except Exception as err:
|
||||
print_info(
|
||||
f"endpoint {url} is not ready, with exception: {err}")
|
||||
print_error(
|
||||
f"Endpoint {url} did not become ready within {timeout} seconds")
|
||||
|
||||
def run_cmd(self, cmd_idx: int, venv) -> str:
|
||||
output = ""
|
||||
server_proc = None
|
||||
@ -294,7 +272,7 @@ class PerfAggrScriptTestCmds(NamedTuple):
|
||||
stderr=subprocess.STDOUT,
|
||||
env=copy.deepcopy(os.environ),
|
||||
)
|
||||
self.wait_for_endpoint_ready(
|
||||
wait_for_endpoint_ready(
|
||||
f"http://{server_hostname}:{server_port}/health",
|
||||
timeout=self.timeout)
|
||||
client_cmd = add_host_port_to_cmd(self.client_cmds[cmd_idx],
|
||||
@ -323,19 +301,6 @@ class PerfDisaggScriptTestCmds(NamedTuple):
|
||||
client_cmd: List[str]
|
||||
benchmark_cmd: List[str]
|
||||
|
||||
def wait_for_endpoint_ready(self, url: str, timeout: int = 600):
|
||||
start = time.monotonic()
|
||||
while time.monotonic() - start < timeout:
|
||||
try:
|
||||
time.sleep(1)
|
||||
if requests.get(url).status_code == 200:
|
||||
print(f"endpoint {url} is ready")
|
||||
return
|
||||
except Exception as err:
|
||||
print(f"endpoint {url} is not ready, with exception: {err}")
|
||||
print_error(
|
||||
f"Endpoint {url} did not become ready within {timeout} seconds")
|
||||
|
||||
def run_cmd(self, cmd_idx: int, venv) -> str:
|
||||
output = ""
|
||||
try:
|
||||
@ -360,7 +325,7 @@ class PerfDisaggScriptTestCmds(NamedTuple):
|
||||
stderr=subprocess.STDOUT,
|
||||
env=venv._new_env,
|
||||
shell=True) as server_proc):
|
||||
self.wait_for_endpoint_ready(
|
||||
wait_for_endpoint_ready(
|
||||
f"http://localhost:8000/health",
|
||||
timeout=1800) # 30 minutes for large models
|
||||
check_output(self.client_cmd, env=venv._new_env)
|
||||
|
||||
@ -5,7 +5,7 @@ threadleak_exclude = asyncio_\d+
|
||||
junit_family=legacy
|
||||
addopts = --ignore-glob="*perf/test_perf.py" --ignore-glob="*perf/disagg/*" --ignore-glob="*test_list_validation.py" --ignore-glob="*llm-test-workspace*" --durations=0 -W ignore::DeprecationWarning
|
||||
pythonpath =
|
||||
../../../examples/auto_deploy
|
||||
../../../examples/auto_deploy ../../
|
||||
norecursedirs = ./triton/perf ./perf/disagg
|
||||
markers =
|
||||
skip_less_device: skip when less device detected than the declared
|
||||
|
||||
@ -1769,6 +1769,21 @@ def test_trtllm_multimodal_benchmark_serving(llm_root, llm_venv):
|
||||
])
|
||||
|
||||
|
||||
@pytest.mark.skip_less_device(4)
|
||||
@pytest.mark.skip_less_device_memory(40000)
|
||||
@pytest.mark.parametrize("service_discovery", ["etcd", "http"])
|
||||
def test_openai_disagg_multi_nodes_completion_service_discovery(
|
||||
llm_root, llm_venv, service_discovery):
|
||||
test_root = unittest_path() / "llmapi" / "apps"
|
||||
llm_venv.run_cmd([
|
||||
"-m",
|
||||
"pytest",
|
||||
str(test_root /
|
||||
f"_test_disagg_serving_multi_nodes_service_discovery.py::test_completion[{service_discovery}]"
|
||||
),
|
||||
])
|
||||
|
||||
|
||||
@pytest.mark.skip_less_device(4)
|
||||
@pytest.mark.skip_less_device_memory(40000)
|
||||
@pytest.mark.parametrize("gen_config",
|
||||
|
||||
@ -11,3 +11,5 @@ test_e2e.py::test_multi_nodes_eval[Kimi-K2-Instruct-tp16-mmlu]
|
||||
test_e2e.py::test_multi_nodes_eval[nemotron-nas/Llama-3_1-Nemotron-Ultra-253B-v1-tp16-mmlu]
|
||||
test_e2e.py::test_openai_disagg_multi_nodes_completion[ctx_tp2pp1-gen_tp2pp1]
|
||||
test_e2e.py::test_openai_disagg_multi_nodes_completion[ctx_tp1pp2-gen_tp1pp2]
|
||||
test_e2e.py::test_openai_disagg_multi_nodes_completion_service_discovery[http]
|
||||
test_e2e.py::test_openai_disagg_multi_nodes_completion_service_discovery[etcd]
|
||||
|
||||
@ -42,6 +42,7 @@ l0_dgx_h100:
|
||||
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-False-True]
|
||||
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-True-False]
|
||||
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-True-True]
|
||||
- unittest/llmapi/apps/test_disagg_serving_perf_metrics.py
|
||||
# ------------- AutoDeploy tests ---------------
|
||||
- accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype[False-2]
|
||||
# llmapi
|
||||
|
||||
@ -4,12 +4,11 @@ import ast
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, NamedTuple
|
||||
|
||||
import requests
|
||||
import yaml
|
||||
from test_common.http_utils import wait_for_endpoint_ready
|
||||
|
||||
|
||||
def get_node_name() -> str:
|
||||
@ -568,19 +567,6 @@ class PerfServerBenchmarkCmds(NamedTuple):
|
||||
names: List[str]
|
||||
working_dir: str
|
||||
|
||||
def wait_for_endpoint_ready(self, url: str, timeout: int = 5400):
|
||||
start = time.monotonic()
|
||||
while time.monotonic() - start < timeout:
|
||||
try:
|
||||
time.sleep(10)
|
||||
if requests.get(url, timeout=5).status_code == 200:
|
||||
print(f"endpoint {url} is ready")
|
||||
return
|
||||
except Exception as err:
|
||||
print(f"endpoint {url} is not ready, with exception: {err}")
|
||||
print_error(
|
||||
f"Endpoint {url} did not become ready within {timeout} seconds")
|
||||
|
||||
def run_cmd(self,
|
||||
cmd_idx: int,
|
||||
node_name: str,
|
||||
@ -601,8 +587,8 @@ class PerfServerBenchmarkCmds(NamedTuple):
|
||||
stderr=subprocess.STDOUT)
|
||||
|
||||
# Wait for server to be ready
|
||||
self.wait_for_endpoint_ready("http://localhost:8000/v1/models",
|
||||
timeout=max_timeout)
|
||||
wait_for_endpoint_ready("http://localhost:8000/v1/models",
|
||||
timeout=max_timeout)
|
||||
|
||||
# Save node name, gpu info, server config, client config output to server file path
|
||||
with open(client_file_path, 'w') as client_ctx:
|
||||
|
||||
tests/test_common/__init__.py (new file, 0 lines)
tests/test_common/http_utils.py (new file, 29 lines)
@ -0,0 +1,29 @@
import time

import requests


def wait_for_endpoint_ready(url: str, timeout: int = 300):
    start = time.monotonic()
    while time.monotonic() - start < timeout:
        try:
            time.sleep(1)
            if requests.get(url, timeout=5).status_code == 200:
                print(f"endpoint {url} is ready")
                return
        except Exception as err:
            print(f"endpoint {url} is not ready, with exception: {err}")
    raise RuntimeError(f"Endpoint {url} did not become ready within {timeout} seconds")


def wait_for_endpoint_down(url: str, timeout: int = 300):
    start = time.monotonic()
    while time.monotonic() - start < timeout:
        try:
            if requests.get(url, timeout=5).status_code >= 100:
                print(f"endpoint {url} returned status code {requests.get(url).status_code}")
            time.sleep(1)
        except Exception as err:
            print(f"endpoint {url} is down, with exception: {err}")
            return
    raise RuntimeError(f"Endpoint {url} did not become down within {timeout} seconds")
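A typical call pattern for these shared helpers, matching how the perf and multi-node tests below use them (the URL and timeouts are illustrative):

# Illustrative: poll a locally launched trtllm-serve instance, then wait for shutdown.
from test_common.http_utils import wait_for_endpoint_down, wait_for_endpoint_ready

health_url = "http://localhost:8000/health"          # hypothetical port
wait_for_endpoint_ready(health_url, timeout=1800)    # raises RuntimeError on timeout
# ... run the benchmark or test body ...
wait_for_endpoint_down(health_url, timeout=300)      # returns once requests start failing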
tests/test_common/perf_metrics_utils.py (new file, 188 lines)
@ -0,0 +1,188 @@
|
||||
import requests
|
||||
|
||||
|
||||
def get_timing_metrics(server_url: str):
|
||||
response = requests.get(f"{server_url}/perf_metrics", timeout=10)
|
||||
assert response.status_code == 200
|
||||
perf_metrics = response.json()
|
||||
assert len(perf_metrics) > 0
|
||||
return perf_metrics[0]
|
||||
|
||||
|
||||
def validate_timing_metrics(perf_metrics_item, request_context="", time_tolerance_seconds=0.005):
|
||||
"""Helper function to validate timing metrics relationships.
|
||||
|
||||
Args:
|
||||
perf_metrics_item: A single performance metrics item from the /perf_metrics endpoint
|
||||
request_context: String context for error messages (e.g., "request 1", "streaming")
|
||||
"""
|
||||
# Validate basic structure
|
||||
required_keys = [
|
||||
"ctx_server",
|
||||
"gen_server",
|
||||
"ctx_perf_metrics",
|
||||
"gen_perf_metrics",
|
||||
"disagg_server_arrival_time",
|
||||
"disagg_server_first_token_time",
|
||||
]
|
||||
for key in required_keys:
|
||||
assert key in perf_metrics_item, f"Missing key: {key} in {request_context}"
|
||||
|
||||
assert (
|
||||
perf_metrics_item["ctx_perf_metrics"]["ctx_request_id"]
|
||||
== perf_metrics_item["gen_perf_metrics"]["ctx_request_id"]
|
||||
)
|
||||
|
||||
# Extract timing metrics
|
||||
ctx_metrics = perf_metrics_item["ctx_perf_metrics"]["perf_metrics"]["timing_metrics"]
|
||||
gen_metrics = perf_metrics_item["gen_perf_metrics"]["perf_metrics"]["timing_metrics"]
|
||||
disagg_arrival = perf_metrics_item["disagg_server_arrival_time"]
|
||||
disagg_first_token = perf_metrics_item["disagg_server_first_token_time"]
|
||||
|
||||
# Validate disaggregated server timing metrics
|
||||
assert disagg_arrival is not None, f"disagg_server_arrival_time is None in {request_context}"
|
||||
assert disagg_first_token is not None, (
|
||||
f"disagg_server_first_token_time is None in {request_context}"
|
||||
)
|
||||
assert isinstance(disagg_arrival, (int, float)), (
|
||||
f"disagg_server_arrival_time is not numeric in {request_context}"
|
||||
)
|
||||
assert isinstance(disagg_first_token, (int, float)), (
|
||||
f"disagg_server_first_token_time is not numeric in {request_context}"
|
||||
)
|
||||
assert disagg_arrival > 0, f"disagg_server_arrival_time is not positive in {request_context}"
|
||||
assert disagg_first_token > 0, (
|
||||
f"disagg_server_first_token_time is not positive in {request_context}"
|
||||
)
|
||||
assert disagg_arrival <= disagg_first_token, (
|
||||
f"disagg_server_arrival_time > disagg_server_first_token_time in {request_context}"
|
||||
)
|
||||
|
||||
# Validate server-level timing metrics for context server
|
||||
ctx_server_arrival = ctx_metrics.get("server_arrival_time")
|
||||
ctx_server_first_token = ctx_metrics.get("server_first_token_time")
|
||||
assert ctx_server_arrival is not None, f"ctx server_arrival_time is None in {request_context}"
|
||||
assert ctx_server_first_token is not None, (
|
||||
f"ctx server_first_token_time is None in {request_context}"
|
||||
)
|
||||
assert isinstance(ctx_server_arrival, (int, float)), (
|
||||
f"ctx server_arrival_time is not numeric in {request_context}"
|
||||
)
|
||||
assert isinstance(ctx_server_first_token, (int, float)), (
|
||||
f"ctx server_first_token_time is not numeric in {request_context}"
|
||||
)
|
||||
assert ctx_server_arrival <= ctx_server_first_token, (
|
||||
f"ctx server_arrival_time > server_first_token_time in {request_context}"
|
||||
)
|
||||
assert ctx_metrics["last_token_time"] - ctx_server_first_token < 1e-3
|
||||
|
||||
# Validate server-level timing metrics for generation server
|
||||
gen_server_arrival = gen_metrics.get("server_arrival_time")
|
||||
gen_server_first_token = gen_metrics.get("server_first_token_time")
|
||||
assert gen_server_arrival is not None, f"gen server_arrival_time is None in {request_context}"
|
||||
assert gen_server_first_token is not None, (
|
||||
f"gen server_first_token_time is None in {request_context}"
|
||||
)
|
||||
assert isinstance(gen_server_arrival, (int, float)), (
|
||||
f"gen server_arrival_time is not numeric in {request_context}"
|
||||
)
|
||||
assert isinstance(gen_server_first_token, (int, float)), (
|
||||
f"gen server_first_token_time is not numeric in {request_context}"
|
||||
)
|
||||
assert gen_server_arrival <= gen_server_first_token, (
|
||||
f"gen server_arrival_time > server_first_token_time in {request_context}"
|
||||
)
|
||||
|
||||
# Validate timing relationships between different levels
|
||||
# Disaggregated server should receive request before individual servers
|
||||
# Allow some tolerance of a local network ping time when comparing the times from disagg and ctx/gen servers
|
||||
# by taking consideration of the error of NTP (1/2 ping time).
|
||||
assert disagg_arrival <= ctx_server_arrival + time_tolerance_seconds, (
|
||||
f"disagg_arrival {disagg_arrival} > ctx_server_arrival {ctx_server_arrival} in {request_context}"
|
||||
)
|
||||
assert disagg_arrival <= gen_server_arrival + time_tolerance_seconds, (
|
||||
f"disagg_arrival {disagg_arrival} > gen_server_arrival {gen_server_arrival} in {request_context}"
|
||||
)
|
||||
|
||||
# Context should complete before generation starts
|
||||
assert ctx_server_first_token <= gen_server_arrival + time_tolerance_seconds, (
|
||||
f"ctx_server_first_token > gen_server_arrival in {request_context}"
|
||||
)
|
||||
|
||||
# Validate internal timing consistency
|
||||
ctx_arrival_time = ctx_metrics["arrival_time"]
|
||||
ctx_first_token_time = ctx_metrics["first_token_time"]
|
||||
gen_arrival_time = gen_metrics["arrival_time"]
|
||||
gen_first_token_time = gen_metrics["first_token_time"]
|
||||
|
||||
assert ctx_arrival_time <= ctx_first_token_time, (
|
||||
f"ctx arrival_time > first_token_time in {request_context}"
|
||||
)
|
||||
assert gen_arrival_time <= gen_first_token_time, (
|
||||
f"gen arrival_time > first_token_time in {request_context}"
|
||||
)
|
||||
|
||||
# Test KV cache transfer timing (if present)
|
||||
if "kv_cache_transfer_start" in gen_metrics and "kv_cache_transfer_end" in gen_metrics:
|
||||
kv_start = gen_metrics["kv_cache_transfer_start"]
|
||||
kv_end = gen_metrics["kv_cache_transfer_end"]
|
||||
assert gen_metrics["kv_cache_size"] > 0
|
||||
assert kv_start <= kv_end, (
|
||||
f"kv_cache_transfer_start > kv_cache_transfer_end in {request_context}"
|
||||
)
|
||||
assert gen_arrival_time <= kv_start, (
|
||||
f"gen_arrival_time > kv_cache_transfer_start in {request_context}"
|
||||
)
|
||||
assert kv_end <= gen_metrics["first_scheduled_time"], (
|
||||
f"kv_cache_transfer_end > first_scheduled_time in {request_context}"
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def get_prometheus_metrics(server_url: str):
    response = requests.get(server_url + "/prometheus/metrics")
    assert response.status_code == 200
    # Parse Prometheus metrics lines into a dictionary of {metric_name: value}
    metrics = {}
    print(response.text)
    for line in response.text.split("\n"):
        if line.startswith("#") or not line.strip():
            continue
        parts = line.split()
        if len(parts) < 2:
            continue
        metric = parts[0]
        try:
            value = float(parts[1])
        except ValueError:
            continue
        import re

        if bucket_match := re.match(r'(.+)_bucket\{le="([^"]+)"\}', metric):
            # Try to parse bucket boundaries out of metrics like ..._bucket{le="0.005"}
            base_metric, le_value = bucket_match.groups()
            if base_metric not in metrics:
                metrics[base_metric] = {}
            try:
                metrics[base_metric][float(le_value)] = value
            except ValueError:
                continue
        elif sum_match := re.match(r"(.+)_sum$", metric):
            base_metric = sum_match.groups()[0]
            if base_metric not in metrics:
                metrics[base_metric] = {}
            metrics[base_metric]["sum"] = value
        elif count_match := re.match(r"(.+)_count$", metric):
            base_metric = count_match.groups()[0]
            if base_metric not in metrics:
                metrics[base_metric] = {}
            metrics[base_metric]["count"] = value
        elif total_match := re.match(r"(.+)_total$", metric):
            base_metric = total_match.groups()[0]
            print(f"Total metric {metric}: {base_metric} = {value}")
            metrics[base_metric] = value
        else:
            # ignore prometheus built-in metrics
            pass
    return metrics
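For reference, the mapping this parser produces looks roughly like the following; the metric names come from the definitions earlier in this diff, the numbers are made up:

# Hypothetical parsed output: counters collapse to a float (the "_total" suffix is stripped),
# histograms collapse to a dict of bucket upper bounds plus "sum" and "count".
example_metrics = {
    "total_requests": 42.0,
    "queue_latency_seconds": {
        0.005: 30.0,          # cumulative observations <= 5 ms
        0.05: 41.0,
        float("inf"): 42.0,
        "sum": 0.73,          # total observed seconds
        "count": 42.0,
    },
}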
@ -4,11 +4,15 @@ import time
|
||||
|
||||
import openai
|
||||
import pytest
|
||||
import requests
|
||||
from test_common.http_utils import (wait_for_endpoint_down,
|
||||
wait_for_endpoint_ready)
|
||||
from test_common.perf_metrics_utils import (get_timing_metrics,
|
||||
validate_timing_metrics)
|
||||
|
||||
from ..test_llm import get_model_path
|
||||
from .openai_server import RemoteDisaggOpenAIServer, RemoteOpenAIServer
|
||||
from .utils import expand_slurm_nodelist
|
||||
from .utils import (expand_slurm_nodelist, wait_for_endpoint_down,
|
||||
wait_for_endpoint_ready)
|
||||
|
||||
RANK = int(os.environ.get("SLURM_PROCID", 0))
|
||||
NODE_RANK = int(os.environ.get("SLURM_NODEID", 0))
|
||||
@ -19,7 +23,8 @@ pytestmark = pytest.mark.threadleak(enabled=False)
|
||||
|
||||
# This test assumes that there are >2 nodes, we run ctx/disagg-server/client on the first node,
|
||||
# and run gen the second node.
|
||||
|
||||
# This is a multi-node test, and will not be scheduled to the same node running other tests
|
||||
# using fixed ports should be safe.
|
||||
CTX_SERVER_PORT = 8001
|
||||
GEN_SERVER_PORT = 8002
|
||||
DISAGG_SERVER_PORT = 8000
|
||||
@ -65,6 +70,7 @@ def env():
|
||||
k: v
|
||||
for k, v in os.environ.items()
|
||||
if not ('PMI_' in k or 'OMPI_' in k or 'PMIX_' in k or 'SLURM_' in k)
|
||||
and k not in ["UCX_TLS", "UCX_NET_DEVICES"] # avoid UCX failure on oci
|
||||
}
|
||||
|
||||
|
||||
@ -105,6 +111,8 @@ def worker(model_name: str, ctx_tp_pp_size: tuple, gen_tp_pp_size: tuple):
|
||||
"enable_block_reuse": False,
|
||||
},
|
||||
"disable_overlap_scheduler": True,
|
||||
"perf_metrics_max_requests": 1000,
|
||||
"return_perf_metrics": True,
|
||||
}
|
||||
if is_ctx_node():
|
||||
print(f"starting ctx_server for rank {RANK} node rank {NODE_RANK}")
|
||||
@ -138,32 +146,6 @@ def worker(model_name: str, ctx_tp_pp_size: tuple, gen_tp_pp_size: tuple):
|
||||
yield None
|
||||
|
||||
|
||||
def wait_for_endpoint_ready(url: str, timeout: int = 300):
|
||||
start = time.monotonic()
|
||||
while time.monotonic() - start < timeout:
|
||||
try:
|
||||
time.sleep(1)
|
||||
if requests.get(url).status_code == 200:
|
||||
print(f"endpoint {url} is ready")
|
||||
return
|
||||
except Exception as err:
|
||||
print(f"endpoint {url} is not ready, with exception: {err}")
|
||||
|
||||
|
||||
def wait_for_endpoint_down(url: str, timeout: int = 300):
|
||||
start = time.monotonic()
|
||||
while time.monotonic() - start < timeout:
|
||||
try:
|
||||
if requests.get(url).status_code >= 100:
|
||||
print(
|
||||
f"endpoint {url} returned status code {requests.get(url).status_code}"
|
||||
)
|
||||
time.sleep(1)
|
||||
except Exception as err:
|
||||
print(f"endpoint {url} is down, with exception: {err}")
|
||||
return
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def disagg_server(worker: RemoteOpenAIServer):
|
||||
if is_disagg_node():
|
||||
@ -210,6 +192,14 @@ def test_completion(client: openai.OpenAI,
|
||||
assert completion.id is not None
|
||||
message = completion.choices[0].text
|
||||
assert message.startswith('2.')
|
||||
|
||||
        perf_metrics = get_timing_metrics(disagg_server.url_root)
        # allow 5ms leniency when comparing the time points from disagg and ctx/gen servers
        validate_timing_metrics(perf_metrics,
                                "multinode test_completion",
                                time_tolerance_seconds=0.005)
|
||||
# sleep 10 seconds to ensure a successful wait_for_endpoint_ready on rank1
|
||||
time.sleep(10)
|
||||
disagg_server.terminate()
|
||||
|
||||
elif is_gen_node():
|
||||
|
||||
@ -0,0 +1,220 @@
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
import uuid
|
||||
|
||||
import openai
|
||||
import pytest
|
||||
from test_common.perf_metrics_utils import get_timing_metrics, validate_timing_metrics
|
||||
|
||||
from tensorrt_llm._utils import get_free_port
|
||||
from tensorrt_llm.llmapi.disagg_utils import ServerRole
|
||||
|
||||
from ..test_llm import get_model_path
|
||||
from .openai_server import RemoteDisaggOpenAIServer, RemoteOpenAIServer
|
||||
from .utils import expand_slurm_nodelist, wait_for_endpoint_down, wait_for_endpoint_ready
|
||||
|
||||
RANK = int(os.environ.get("SLURM_PROCID", 0))
|
||||
NODE_RANK = int(os.environ.get("SLURM_NODEID", 0))
|
||||
NODE_LIST = expand_slurm_nodelist(os.environ.get("SLURM_NODELIST", ""))
|
||||
SLURM_NTASKS_PER_NODE = int(os.environ.get("SLURM_NTASKS_PER_NODE", 1))
|
||||
|
||||
# This a multi-node QA test, use a fixed port instead of finding a free port
|
||||
# so that all nodes can have the same disagg server config
|
||||
DISAGG_SERVER_PORT = 8000
|
||||
|
||||
|
||||
# This test is supposed to run with 2 nodes or more
|
||||
def is_ctx_node():
|
||||
assert len(NODE_LIST) == 2
|
||||
return NODE_RANK == 0
|
||||
|
||||
|
||||
def is_gen_node():
|
||||
assert len(NODE_LIST) == 2
|
||||
return NODE_RANK == 1
|
||||
|
||||
|
||||
def is_disagg_node():
|
||||
return NODE_RANK == 0
|
||||
|
||||
|
||||
# The test is run on multinodes but only the first node's output is used for assertion
|
||||
def is_pytest_node():
|
||||
return NODE_RANK == 0
|
||||
|
||||
|
||||
def env():
|
||||
# Remove MPI related environment variables to isolate the ctx/gen processes
|
||||
# so that they will not be in the same MPI communicator, otherwise the rank and world_size may mismatch
|
||||
return {
|
||||
k: v
|
||||
for k, v in os.environ.items()
|
||||
if not ("PMI_" in k or "OMPI_" in k or "PMIX_" in k or "SLURM_" in k)
|
||||
and k not in ["UCX_TLS", "UCX_NET_DEVICES"]
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def model_name():
|
||||
return "llama-3.1-model/Llama-3.1-8B-Instruct"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def disagg_host():
|
||||
return NODE_LIST[0]
|
||||
|
||||
|
||||
@pytest.fixture(params=["etcd", "http"])
|
||||
def service_discovery(request, disagg_host: str):
|
||||
if request.param == "etcd":
|
||||
work_dir = tempfile.mkdtemp()
|
||||
data_dir = f"{work_dir}/disagg_test-etcd-{uuid.uuid4()}"
|
||||
etcd = subprocess.Popen(["etcd", "--data-dir", data_dir])
|
||||
yield etcd, f"etcd://{disagg_host}:2379"
|
||||
try:
|
||||
etcd.kill()
|
||||
etcd.wait(timeout=10)
|
||||
shutil.rmtree(data_dir)
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
yield None, f"http://{disagg_host}:{DISAGG_SERVER_PORT}"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def disagg_cluster_config(service_discovery: tuple):
|
||||
_, uri = service_discovery
|
||||
return {
|
||||
"cluster_uri": uri,
|
||||
"cluster_name": "",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def worker(model_name: str, disagg_cluster_config: dict):
|
||||
extra_config = {
|
||||
"disagg_cluster": disagg_cluster_config,
|
||||
"cache_transceiver_config": {"backend": "DEFAULT"},
|
||||
"kv_cache_config": {
|
||||
"free_gpu_memory_fraction": 0.5,
|
||||
"enable_block_reuse": False,
|
||||
},
|
||||
"disable_overlap_scheduler": True,
|
||||
"return_perf_metrics": True,
|
||||
"perf_metrics_max_requests": 1000,
|
||||
}
|
||||
# start workers on 0.0.0.0:<free_port>, then the workers should be able to
|
||||
# report their correct hostname:port to the disagg server
|
||||
port = get_free_port()
|
||||
if is_ctx_node():
|
||||
print(f"starting ctx_server for rank {RANK} node rank {NODE_RANK}")
|
||||
model_path = get_model_path(model_name)
|
||||
tp_size, pp_size = 1, 1
|
||||
args = ["--tp_size", str(tp_size), "--pp_size", str(pp_size)]
|
||||
with RemoteOpenAIServer(
|
||||
model_path,
|
||||
port=port,
|
||||
cli_args=args,
|
||||
host="0.0.0.0",
|
||||
env=env(),
|
||||
llmapi_launch=False,
|
||||
rank=RANK % SLURM_NTASKS_PER_NODE,
|
||||
extra_config=extra_config,
|
||||
role=ServerRole.CONTEXT,
|
||||
) as server:
|
||||
yield server
|
||||
elif is_gen_node():
|
||||
print(f"starting gen_server for rank {RANK} node rank {NODE_RANK}")
|
||||
model_path = get_model_path(model_name)
|
||||
tp_size, pp_size = 1, 1
|
||||
args = ["--tp_size", str(tp_size), "--pp_size", str(pp_size)]
|
||||
with RemoteOpenAIServer(
|
||||
model_path,
|
||||
port=port,
|
||||
cli_args=args,
|
||||
host="0.0.0.0",
|
||||
env=env(),
|
||||
llmapi_launch=False,
|
||||
rank=RANK % SLURM_NTASKS_PER_NODE,
|
||||
extra_config=extra_config,
|
||||
role=ServerRole.GENERATION,
|
||||
) as server:
|
||||
yield server
|
||||
else:
|
||||
yield None
|
||||
|
||||
|
||||
# different from non-service-discovery version, disagg server doesn't have to
|
||||
# wait for ctx/gen servers to get ready
|
||||
@pytest.fixture
|
||||
def disagg_server(disagg_cluster_config: dict):
|
||||
if is_disagg_node():
|
||||
disagg_config = {
|
||||
"disagg_cluster": disagg_cluster_config,
|
||||
"port": DISAGG_SERVER_PORT,
|
||||
"hostname": "0.0.0.0",
|
||||
"perf_metrics_max_requests": 1000,
|
||||
}
|
||||
print(f"starting disagg_server for rank {RANK} node rank {NODE_RANK}")
|
||||
# ctx/gen servers are unnecessary for service discovery test
|
||||
with RemoteDisaggOpenAIServer(
|
||||
ctx_servers=[],
|
||||
gen_servers=[],
|
||||
port=DISAGG_SERVER_PORT,
|
||||
disagg_config=disagg_config,
|
||||
llmapi_launch=False,
|
||||
env=env(),
|
||||
wait_ready=False, # wait it to be ready in test body
|
||||
) as server:
|
||||
yield server
|
||||
else:
|
||||
print(f"skipping disagg_server for rank {RANK} node rank {NODE_RANK}")
|
||||
yield None
|
||||
|
||||
|
||||
@pytest.fixture
def client(disagg_server: RemoteDisaggOpenAIServer):
    if is_pytest_node():
        return disagg_server.get_client()
    else:
        print(f"skipping client for rank {RANK} node rank {NODE_RANK}")
        return None


def test_completion(
    disagg_server: RemoteDisaggOpenAIServer,
    worker: RemoteOpenAIServer,
    client: openai.OpenAI,
    disagg_host: str,
    model_name: str,
):
    disagg_health_url = f"http://{disagg_host}:{DISAGG_SERVER_PORT}/health/"
    wait_for_endpoint_ready(disagg_health_url)
    if is_pytest_node():
        print(f"running test_completion on rank {RANK} node rank {NODE_RANK}")
        prompt = "What is the result of 1+1? Answer in one word: "
        for _ in range(10):
            completion = client.completions.create(
                model=model_name,
                prompt=prompt,
                max_tokens=10,
                temperature=0.0,
            )
            print(f"Output: {completion.choices[0].text}")
            assert completion.id is not None
            message = completion.choices[0].text
            assert message.startswith("2.")

        perf_metrics = get_timing_metrics(disagg_server.url_root)
        validate_timing_metrics(perf_metrics, "multinode test_completion")

        disagg_server.terminate()

    elif is_gen_node():
        # keep gen workers alive until the test ends
        wait_for_endpoint_down(disagg_health_url)
        assert True
    else:
        assert True

@ -1,6 +1,5 @@
import os
import tempfile
from typing import List

import openai
import pytest
@ -8,42 +7,11 @@ import requests
import yaml

from ..test_llm import get_model_path
from .openai_server import RemoteOpenAIServer
from .openai_server import RemoteMMEncoderServer

pytestmark = pytest.mark.threadleak(enabled=False)


class RemoteMMEncoderServer(RemoteOpenAIServer):
    """Remote server for testing multimodal encoder endpoints."""

    def __init__(self,
                 model: str,
                 cli_args: List[str] = None,
                 port: int = None) -> None:
        # Reuse parent initialization but change the command
        import subprocess
        import sys

        from tensorrt_llm.llmapi.mpi_session import find_free_port

        self.host = "localhost"
        self.port = port if port is not None else find_free_port()
        self.rank = os.environ.get("SLURM_PROCID", 0)

        args = ["--host", f"{self.host}", "--port", f"{self.port}"]
        if cli_args:
            args += cli_args

        # Use mm_embedding_serve command instead of regular serve
        launch_cmd = ["trtllm-serve", "mm_embedding_serve"] + [model] + args

        self.proc = subprocess.Popen(launch_cmd,
                                     stdout=sys.stdout,
                                     stderr=sys.stderr)
        self._wait_for_server(url=self.url_for("health"),
                              timeout=self.MAX_SERVER_START_WAIT_S)


@pytest.fixture(scope="module", ids=["Qwen2.5-VL-3B-Instruct"])
def model_name():
    return "Qwen2.5-VL-3B-Instruct"

@ -11,7 +11,8 @@ import openai
import requests
import yaml

from tensorrt_llm.llmapi.mpi_session import find_free_port
from tensorrt_llm._utils import get_free_port
from tensorrt_llm.llmapi.disagg_utils import ServerRole


class RemoteOpenAIServer:
@ -26,13 +27,21 @@ class RemoteOpenAIServer:
                 host: str = "localhost",
                 env: Optional[dict] = None,
                 rank: int = -1,
                 extra_config: Optional[dict] = None) -> None:
                 extra_config: Optional[dict] = None,
                 log_path: Optional[str] = None,
                 wait: bool = True,
                 role: Optional[ServerRole] = None) -> None:
        self.host = host
        self.port = port if port is not None else find_free_port()
        self.port = port if port is not None else get_free_port()
        self.rank = rank if rank != -1 else int(
            os.environ.get("SLURM_PROCID", 0))
        self.extra_config_file = None
        self.log_path = log_path
        self.log_file = None
        self.role = role
        args = ["--host", f"{self.host}", "--port", f"{self.port}"]
        if self.role is not None:
            args += ["--server_role", self.role.name]
        if cli_args:
            args += cli_args
        if extra_config:
@ -50,10 +59,19 @@ class RemoteOpenAIServer:
            env = os.environ.copy()
        self.proc = subprocess.Popen(launch_cmd,
                                     env=env,
                                     stdout=sys.stdout,
                                     stderr=sys.stderr)
        self._wait_for_server(url=self.url_for("health"),
                              timeout=self.MAX_SERVER_START_WAIT_S)
                                     stdout=self._get_output(),
                                     stderr=self._get_output())
        if wait:
            self.wait_for_server(timeout=self.MAX_SERVER_START_WAIT_S)

    def _get_output(self):
        if self.log_file:
            return self.log_file
        elif self.log_path:
            self.log_file = open(self.log_path, "w+")
            return self.log_file
        else:
            return sys.stdout

    def __enter__(self):
        return self
@ -76,6 +94,12 @@ class RemoteOpenAIServer:
        except Exception as e:
            print(f"Error removing extra config file: {e}")
        self.proc = None
        if self.log_file:
            self.log_file.close()
            self.log_file = None

    def wait_for_server(self, timeout: float):
        self._wait_for_server(url=self.url_for("health"), timeout=timeout)

    def _wait_for_server(self, *, url: str, timeout: float):
        # run health check on the first rank only.
@ -128,21 +152,28 @@ class RemoteDisaggOpenAIServer(RemoteOpenAIServer):
                 gen_servers: List[str],
                 port: int = -1,
                 env: Optional[dict] = None,
                 llmapi_launch: bool = False) -> None:
                 llmapi_launch: bool = False,
                 disagg_config: Optional[dict] = None,
                 log_path: Optional[str] = None,
                 wait_ready: bool = True) -> None:
        self.ctx_servers = ctx_servers
        self.gen_servers = gen_servers
        self.host = "localhost"
        self.port = find_free_port() if port is None or port < 0 else port
        self.host = "0.0.0.0"
        self.port = get_free_port() if port is None or port < 0 else port
        self.rank = 0
        with tempfile.NamedTemporaryFile(mode="w+",
                                         delete=False,
                                         delete_on_close=False) as f:
            f.write(self._get_extra_config())
            f.flush()
            self.extra_config_file = f.name
        self.disagg_config = self._get_extra_config()
        if disagg_config:
            self.disagg_config.update(disagg_config)
        self.log_path = log_path
        self.log_file = None
        self.extra_config_file = os.path.join(
            tempfile.gettempdir(), f"disagg_config_{self.port}.yaml")
        with open(self.extra_config_file, "w+") as f:
            yaml.dump(self.disagg_config, f)
        launch_cmd = [
            "trtllm-serve", "disaggregated", "-c", self.extra_config_file
        ]
        print(f"launch_cmd: {launch_cmd}, extra_config: {self.disagg_config}")
        if llmapi_launch:
            # start server with llmapi-launch on multi nodes
            launch_cmd = ["trtllm-llmapi-launch"] + launch_cmd
@ -150,13 +181,14 @@ class RemoteDisaggOpenAIServer(RemoteOpenAIServer):
            env = os.environ.copy()
        self.proc = subprocess.Popen(launch_cmd,
                                     env=env,
                                     stdout=sys.stdout,
                                     stderr=sys.stderr)
        self._wait_for_server(url=self.url_for("health"),
                              timeout=self.MAX_SERVER_START_WAIT_S)
                                     stdout=self._get_output(),
                                     stderr=self._get_output())
        if wait_ready:
            self._wait_for_server(url=self.url_for("health"),
                                  timeout=self.MAX_SERVER_START_WAIT_S)

    def _get_extra_config(self):
        return yaml.dump({
        return {
            "context_servers": {
                "num_instances": len(self.ctx_servers),
                "urls": self.ctx_servers
@ -167,4 +199,38 @@ class RemoteDisaggOpenAIServer(RemoteOpenAIServer):
            },
            "port": self.port,
            "hostname": self.host,
        })
            "perf_metrics_max_requests": 1000,
        }


class RemoteMMEncoderServer(RemoteOpenAIServer):
    """Remote server for testing multimodal encoder endpoints."""

    def __init__(self,
                 model: str,
                 cli_args: List[str] = None,
                 port: int = None,
                 log_path: Optional[str] = None) -> None:
        # Reuse parent initialization but change the command
        import subprocess

        from tensorrt_llm._utils import get_free_port

        self.host = "localhost"
        self.port = port if port is not None else get_free_port()
        self.rank = os.environ.get("SLURM_PROCID", 0)
        self.log_path = log_path
        self.log_file = None

        args = ["--host", f"{self.host}", "--port", f"{self.port}"]
        if cli_args:
            args += cli_args

        # Use mm_embedding_serve command instead of regular serve
        launch_cmd = ["trtllm-serve", "mm_embedding_serve"] + [model] + args

        self.proc = subprocess.Popen(launch_cmd,
                                     stdout=self._get_output(),
                                     stderr=self._get_output())
        self._wait_for_server(url=self.url_for("health"),
                              timeout=self.MAX_SERVER_START_WAIT_S)

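For context, a minimal sketch of what the merge-and-dump flow above produces on disk; it is not part of the change, the port and cluster values are made up, and keys not visible in this hunk (such as the generation_servers block) are elided.

# Illustrative sketch only: made-up values, mirrors the merge-and-dump steps above.
import os
import tempfile

import yaml

disagg_config = {
    "context_servers": {"num_instances": 0, "urls": []},
    # ... generation_servers and other keys elided ...
    "port": 8000,
    "hostname": "0.0.0.0",
    "perf_metrics_max_requests": 1000,
}
# Caller-supplied overrides (e.g. a disagg_cluster block) win over the defaults.
disagg_config.update({
    "disagg_cluster": {"cluster_uri": "http://localhost:8000", "cluster_name": ""},
})

config_path = os.path.join(tempfile.gettempdir(),
                           f"disagg_config_{disagg_config['port']}.yaml")
with open(config_path, "w+") as f:
    yaml.dump(disagg_config, f)
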
tests/unittest/llmapi/apps/test_disagg_serving_perf_metrics.py (new file, 219 lines)
@ -0,0 +1,219 @@
import os
from typing import Tuple

import openai
import pytest
from test_common.http_utils import wait_for_endpoint_ready
from test_common.perf_metrics_utils import (
    get_prometheus_metrics,
    get_timing_metrics,
    validate_timing_metrics,
)

from tensorrt_llm._utils import get_free_ports

from ..test_llm import get_model_path
from .openai_server import RemoteDisaggOpenAIServer, RemoteOpenAIServer


@pytest.fixture
def test_ports():
    return get_free_ports(3)


@pytest.fixture
def disagg_port(test_ports: list[int]):
    return test_ports[0]


@pytest.fixture
def ctx_port(test_ports: list[int]):
    return test_ports[1]


@pytest.fixture
def gen_port(test_ports: list[int]):
    return test_ports[2]


@pytest.fixture
def model_name():
    return "TinyLlama/TinyLlama-1.1B-Chat-v1.0"


@pytest.fixture
def disagg_cluster_config(disagg_port: int):
    return {
        "cluster_uri": f"http://localhost:{disagg_port}",
        "cluster_name": "",
    }


def worker_config(model_name: str, disagg_cluster_config: dict):
    return {
        "model": model_name,
        "disagg_cluster": disagg_cluster_config,
        "cache_transceiver_config": {
            "backend": "DEFAULT",
        },
        "kv_cache_config": {
            "free_gpu_memory_fraction": 0.2,
            "enable_block_reuse": False,
        },
        "disable_overlap_scheduler": True,
        "cuda_graph_config": None,
        "return_perf_metrics": True,
        "perf_metrics_max_requests": 1000,
    }


@pytest.fixture
def workers(model_name: str, disagg_cluster_config: dict, ctx_port: int, gen_port: int):
    model_path = get_model_path(model_name)
    extra_config = worker_config(model_name, disagg_cluster_config)

    def worker(server_role: str, port: int):
        return RemoteOpenAIServer(
            model_path,
            port=port,
            env=os.environ.copy(),
            cli_args=["--server_role", server_role],
            llmapi_launch=False,
            extra_config=extra_config,
            log_path=f"output_{server_role}.log",
            wait=False,
        )

    with worker("context", ctx_port) as ctx_worker, worker("generation", gen_port) as gen_worker:
        yield ctx_worker, gen_worker


@pytest.fixture
def disagg_server(disagg_cluster_config: dict, workers, disagg_port: int):
    disagg_config = {
        "port": disagg_port,
        "disagg_cluster": disagg_cluster_config,
        "perf_metrics_max_requests": 1000,
    }
    with RemoteDisaggOpenAIServer(
            ctx_servers=[],
            gen_servers=[],
            port=disagg_config["port"],
            llmapi_launch=False,
            disagg_config=disagg_config,
    ) as server:
        yield server


@pytest.fixture
def client(disagg_server: RemoteDisaggOpenAIServer):
    return disagg_server.get_client()


@pytest.fixture
def async_client(disagg_server: RemoteDisaggOpenAIServer):
    return disagg_server.get_async_client()


async def send_request(
    client: openai.AsyncOpenAI, stream: bool, repeat: int, max_token: int, model_name: str
):
    for _ in range(repeat):
        prompt = "What is the result of 1+1? Answer in one word: "
        completion = await client.completions.create(
            model=model_name,
            prompt=prompt,
            max_tokens=max_token,
            temperature=0.0,
            stream=stream,
        )
        if stream:
            output = []
            async for chunk in completion:
                output.append(chunk.choices[0].text)
            assert len(output) > 0
            message = "".join(output)
        else:
            assert completion.id is not None
            message = completion.choices[0].text
        assert message.startswith("2.")


def check_historgram(metrics_dict: dict, count: int, range: tuple[float, float]):
    assert metrics_dict["count"] == count
    mean = metrics_dict["sum"] / metrics_dict["count"]
    assert mean > range[0] and mean < range[1]


@pytest.mark.asyncio
@pytest.mark.timeout(300)
async def test_completion_metrics(
    async_client: openai.AsyncOpenAI,
    workers: Tuple[RemoteOpenAIServer, RemoteOpenAIServer],
    disagg_server: RemoteDisaggOpenAIServer,
    model_name: str,
):
    assert len(workers) == 2
    for worker in workers:
        worker.wait_for_server(timeout=120)
    wait_for_endpoint_ready(disagg_server.url_root + "/health")

    max_token = 10
    total_requests = 10
    await send_request(
        client=async_client,
        stream=True,
        repeat=total_requests,
        max_token=max_token,
        model_name=model_name,
    )
    timing_metrics = get_timing_metrics(disagg_server.url_root)
    validate_timing_metrics(timing_metrics, "test_completion_metrics")

    metrics = get_prometheus_metrics(disagg_server.url_root)
    print(metrics)

    for role in ["ctx", "gen"]:
        assert metrics[f"{role}_total_requests"] == total_requests
        assert metrics[f"{role}_completed_requests"] == total_requests
        assert metrics[f"{role}_error_requests"] == 0
        assert f"{role}_retry_requests" in metrics

    check_historgram(metrics["gen_first_token_latency_seconds"], total_requests, (0.0, 0.3))
    check_historgram(metrics["gen_complete_latency_seconds"], total_requests, (0.0, 0.6))

    assert metrics["total_requests"] == total_requests
    assert metrics["stream_requests"] == total_requests
    assert metrics["nonstream_requests"] == 0
    assert metrics["total_responses"] == total_requests
    assert metrics["validation_exceptions"] == 0
    assert metrics["http_exceptions"] == 0
    assert metrics["internal_errors"] == 0
    check_historgram(metrics["queue_latency_seconds"], total_requests, (0.0, 0.03))

    # test the non-streaming path
    await send_request(
        client=async_client,
        stream=False,
        repeat=total_requests,
        max_token=max_token,
        model_name=model_name,
    )

    metrics = get_prometheus_metrics(disagg_server.url_root)
    for role in ["ctx", "gen"]:
        assert metrics[f"{role}_total_requests"] == total_requests * 2
        assert metrics[f"{role}_completed_requests"] == total_requests * 2
        assert metrics[f"{role}_error_requests"] == 0
        assert f"{role}_retry_requests" in metrics

    assert metrics["total_requests"] == total_requests * 2
    assert metrics["stream_requests"] == total_requests
    assert metrics["nonstream_requests"] == total_requests
    assert metrics["total_responses"] == total_requests * 2
    assert metrics["validation_exceptions"] == 0
    assert metrics["http_exceptions"] == 0
    assert metrics["internal_errors"] == 0

    check_historgram(metrics["gen_complete_latency_seconds"], total_requests * 2, (0.0, 0.6))
    check_historgram(metrics["queue_latency_seconds"], total_requests * 2, (0.0, 0.03))

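As a reading aid, not part of the change: the histogram helper above assumes the Prometheus scrape has already been reduced to a dict with the histogram's `count` and `sum` samples, so the assertion bounds the mean latency. A tiny illustration with made-up numbers:

# Made-up values: 10 streamed requests whose time-to-first-token summed to 1.2 s,
# so the mean of 0.12 s falls inside the (0.0, 0.3) window asserted in the test above.
gen_first_token_latency = {"count": 10, "sum": 1.2}
check_historgram(gen_first_token_latency, 10, (0.0, 0.3))
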
@ -14,10 +14,12 @@
# limitations under the License.

import re
import time
from pathlib import Path
from typing import Any, Callable

import pytest
import requests
import yaml

from ..test_llm import get_model_path
@ -257,3 +259,29 @@ def expand_slurm_nodelist(nodelist_str):
            expanded_nodes.append(group)

    return expanded_nodes


def wait_for_endpoint_ready(url: str, timeout: int = 300, interval: int = 3):
    start = time.monotonic()
    while time.monotonic() - start < timeout:
        try:
            time.sleep(interval)
            if requests.get(url).status_code == 200:
                print(f"endpoint {url} is ready")
                return
        except Exception as err:
            print(f"endpoint {url} is not ready, with exception: {err}")


def wait_for_endpoint_down(url: str, timeout: int = 300):
    start = time.monotonic()
    while time.monotonic() - start < timeout:
        try:
            if requests.get(url).status_code >= 100:
                print(
                    f"endpoint {url} returned status code {requests.get(url).status_code}"
                )
            time.sleep(1)
        except Exception as err:
            print(f"endpoint {url} is down, with exception: {err}")
            return

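A hypothetical usage of the two helpers above (the URL here is made up; the tests build it from disagg_host and DISAGG_SERVER_PORT): the driver rank blocks until the disaggregated server answers its health check, and a worker rank later blocks until that endpoint stops answering so it stays alive for the whole test.

# Hypothetical health URL for illustration only.
health_url = "http://localhost:8000/health/"
wait_for_endpoint_ready(health_url, timeout=300, interval=3)
# ... drive requests against the disaggregated server ...
wait_for_endpoint_down(health_url)  # returns once the endpoint becomes unreachable
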
@ -9,6 +9,7 @@ pythonpath =
    ../../examples/auto_deploy
    ../../examples/models/core
    ../../examples
    ../
env =
    D:AUTO_DEPLOY_LOG_LEVEL=INFO
markers =