# Copyright (c) 2025, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
from collections import defaultdict, deque
from dataclasses import dataclass
from typing import Any, Dict, List, Literal, Optional

from tensorrt_llm.llmapi.disagg_utils import ServerRole

COUNTER_METRICS = [
    ("total_requests", "Total number of requests"),
    ("error_requests", "Total number of error requests"),
    ("retry_requests", "Total number of retry requests"),
    ("completed_requests", "Total number of completed requests"),
]
# fmt: off
LONG_TIME_BUCKETS = [
    0.1, 0.3, 0.5, 0.8, 1.0, 1.5, 2.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0,
    60.0, 120.0, 240.0, 480.0, 960.0, 1920.0,
]
SHORT_TIME_BUCKETS = [
    0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5, 0.75, 1.0, 2.5, 5.0,
    7.5, 10.0, 20.0, 40.0, 80.0, 160.0, 640.0, 2560.0,
]
# fmt: on
HISTOGRAM_METRICS = [
    (
        "first_token_latency_seconds",
        "Histogram of latency from request arrival to first token in seconds",
        SHORT_TIME_BUCKETS,
    ),
    (
        "complete_latency_seconds",
        "Histogram of latency from request arrival to last token in seconds",
        LONG_TIME_BUCKETS,
    ),
    (
        "per_token_latency_seconds",
        "Histogram of per-token latency from first token to completion in seconds",
        SHORT_TIME_BUCKETS,
    ),
]

MetricsTypeLiteral = Literal["counter", "histogram"]


@dataclass
class MetricsDefinition:
    """Declarative description of a single Prometheus metric."""

    name: str
    description: str
    type: MetricsTypeLiteral
    buckets: Optional[List[float]] = None  # required for histograms only


CLIENT_METRICS_DEFINITIONS = [
    MetricsDefinition("total_requests", "Total number of requests", "counter"),
    MetricsDefinition("error_requests", "Total number of error requests", "counter"),
    MetricsDefinition("retry_requests", "Total number of retry requests", "counter"),
    MetricsDefinition("completed_requests", "Total number of completed requests", "counter"),
    MetricsDefinition(
        "first_token_latency_seconds",
        "Histogram of latency from request arrival to first token in seconds",
        "histogram",
        SHORT_TIME_BUCKETS,
    ),
    MetricsDefinition(
        "complete_latency_seconds",
        "Histogram of latency from request arrival to last token in seconds",
        "histogram",
        LONG_TIME_BUCKETS,
    ),
    MetricsDefinition(
        "per_token_latency_seconds",
        "Histogram of per-token latency from first token to completion in seconds",
        "histogram",
        SHORT_TIME_BUCKETS,
    ),
]

ROLE_TO_CLIENT_TYPE = {
    ServerRole.CONTEXT: "ctx",
    ServerRole.GENERATION: "gen",
    ServerRole.MM_ENCODER: "mme",
}
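
# Note (illustrative): metrics created through `instance_metric` with a role are
# prefixed with the short client type above, e.g. a CONTEXT client's
# "total_requests" counter is registered as "ctx_total_requests"; role-less
# (server-level) metrics keep the bare definition name.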


def instance_metric(definition: MetricsDefinition, role: Optional[ServerRole] = None):
    # Import lazily so that `set_prometheus_multiproc_dir` can set up the
    # multiprocess environment before prometheus_client is first imported.
    from prometheus_client import Counter, Histogram

    name = (
        f"{ROLE_TO_CLIENT_TYPE[role]}_{definition.name}"
        if role in ROLE_TO_CLIENT_TYPE
        else definition.name
    )
    if definition.type == "counter":
        return Counter(name, definition.description)
    elif definition.type == "histogram":
        return Histogram(name, definition.description, buckets=definition.buckets)
    else:
        raise ValueError(f"Invalid metric type: {definition.type}")
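
# Illustrative usage (sketch; assumes the default prometheus_client registry):
#
#     metric = instance_metric(CLIENT_METRICS_DEFINITIONS[0], ServerRole.CONTEXT)
#     metric.inc()  # bumps the "ctx_total_requests" counter
#
# `Counter.inc()` and `Histogram.observe()` are standard prometheus_client APIs.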


class ClientMetricsCollector:
    def __init__(self, role: ServerRole):
        self._role = role
        self._metrics = {
            definition.name: instance_metric(definition, role)
            for definition in CLIENT_METRICS_DEFINITIONS
        }

    def __getattr__(
        self, key: str
    ):  # no return type hint so prometheus_client is not imported at module level
        return self._metrics[key]
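
# Illustrative usage (sketch): each metric in CLIENT_METRICS_DEFINITIONS is
# exposed as an attribute of the collector, e.g.
#
#     collector = ClientMetricsCollector(ServerRole.GENERATION)
#     collector.total_requests.inc()
#     collector.first_token_latency_seconds.observe(0.03)
#
# which records to the "gen_"-prefixed Prometheus metrics.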


SERVER_METRICS_DEFINITIONS = [
    MetricsDefinition("total_requests", "Total number of requests", "counter"),
    MetricsDefinition("stream_requests", "Total number of stream requests", "counter"),
    MetricsDefinition("nonstream_requests", "Total number of non-stream requests", "counter"),
    MetricsDefinition("validation_exceptions", "Total number of validation exceptions", "counter"),
    MetricsDefinition("http_exceptions", "Total number of HTTP exceptions", "counter"),
    MetricsDefinition("internal_errors", "Total number of internal errors", "counter"),
    MetricsDefinition("total_responses", "Total number of responses", "counter"),
    MetricsDefinition(
        "queue_latency_seconds",
        "Histogram of latency from request arrival to being processed in seconds",
        "histogram",
        SHORT_TIME_BUCKETS,
    ),
]


class DisaggPerfMetricsCollector:
    def __init__(self, max_requests: int):
        self._max_requests = max_requests
        self._request_metrics = deque(maxlen=max_requests)
        self._server_metrics = defaultdict(dict)
        self._lock = asyncio.Lock()
        self._clients = []
        self._metrics = {
            definition.name: instance_metric(definition)
            for definition in SERVER_METRICS_DEFINITIONS
        }

    def add_client(self, client):
        self._clients.append(client)

    def __getattr__(self, key: str):
        return self._metrics[key]

    async def add_per_request_metrics(
        self,
        ctx_server: str,
        gen_server: str,
        ctx_request_id: int,
        server_arrival_time: float,
        server_first_token_time: float,
    ):
        async with self._lock:
            self._request_metrics.append(
                (
                    ctx_server,
                    gen_server,
                    ctx_request_id,
                    server_arrival_time,
                    server_first_token_time,
                )
            )
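
    # Note (illustrative flow): the disaggregated server records one tuple per
    # routed request via `add_per_request_metrics`; `get_perf_metrics` below
    # later joins those tuples with per-server metrics pulled from the
    # registered clients, using `ctx_request_id` as the join key.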

    async def get_perf_metrics(self) -> List[Dict[str, Any]]:
        perf_metrics = {}
        for client in self._clients:
            metrics_dict = await client.collect_metrics()
            perf_metrics.update(metrics_dict)

        return_metrics = []
        async with self._lock:
            for server, metrics_data in perf_metrics.items():
                server_metrics = self._server_metrics[server]
                # avoid metrics map inflation by limiting the number of requests to add
                available_req_num = min(
                    max(0, self._max_requests - len(server_metrics)), len(metrics_data)
                )
                req_metrics_map = {
                    req_metrics["ctx_request_id"]: req_metrics
                    for req_metrics in metrics_data[:available_req_num]
                    if "ctx_request_id" in req_metrics
                }
                server_metrics.update(req_metrics_map)

            remain_keys = []
            for (
                ctx_server,
                gen_server,
                ctx_request_id,
                server_arrival_time,
                server_first_token_time,
            ) in self._request_metrics:
                gen_perf_metrics = self._server_metrics[gen_server].pop(ctx_request_id, None)
                if gen_perf_metrics is None:
                    # generation not finished yet; keep the entry for the next pass
                    remain_keys.append(
                        (
                            ctx_server,
                            gen_server,
                            ctx_request_id,
                            server_arrival_time,
                            server_first_token_time,
                        )
                    )
                    continue
                ctx_perf_metrics = self._server_metrics[ctx_server].pop(ctx_request_id, None)
                # TODO: strip the keys for less repetition and use a table-style response
                return_metrics.append(
                    {
                        "ctx_server": ctx_server,
                        "gen_server": gen_server,
                        "disagg_server_arrival_time": server_arrival_time,
                        "disagg_server_first_token_time": server_first_token_time,
                        "ctx_perf_metrics": ctx_perf_metrics,
                        "gen_perf_metrics": gen_perf_metrics,
                    }
                )
            self._request_metrics = deque(remain_keys, maxlen=self._max_requests)
        return return_metrics
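
# Illustrative shape of one entry returned by `get_perf_metrics` (server URLs and
# timestamps here are made up; the contents of ctx/gen_perf_metrics depend on what
# each client's `collect_metrics` reports):
#
#     {
#         "ctx_server": "http://ctx-0:8001",
#         "gen_server": "http://gen-0:8002",
#         "disagg_server_arrival_time": 1714.25,
#         "disagg_server_first_token_time": 1714.31,
#         "ctx_perf_metrics": {"ctx_request_id": 42, ...},
#         "gen_perf_metrics": {"ctx_request_id": 42, ...},
#     }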