mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
[None][feat] Core Metrics Implementation (#5785)
Signed-off-by: Ye Zhang <zhysishu@gmail.com> Signed-off-by: Shunkang <182541032+Shunkangz@users.noreply.github.co>
This commit is contained in:
parent
97787883c3
commit
bcf5ec0c9a
@ -29,6 +29,8 @@ nvidia-modelopt[torch]~=0.33.0
|
||||
nvidia-nccl-cu12
|
||||
nvidia-cuda-nvrtc-cu12
|
||||
transformers==4.55.0
|
||||
prometheus_client
|
||||
prometheus_fastapi_instrumentator
|
||||
pydantic>=2.9.1
|
||||
pydantic-settings[yaml]
|
||||
omegaconf
|
||||
|
||||
@ -250,6 +250,12 @@ class LlmResult:
|
||||
self._result = tensorrt_llm.bindings.executor.deserialize_result(
|
||||
self._result)
|
||||
|
||||
def get_result(self):
|
||||
if tmp_res := tensorrt_llm.bindings.executor.deserialize_result(
|
||||
self._result):
|
||||
return tmp_res
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class LlmResponse:
|
||||
|
||||
@ -20,6 +20,7 @@ import linecache
|
||||
import math
|
||||
import os
|
||||
import struct
|
||||
import tempfile
|
||||
import trace
|
||||
import weakref
|
||||
from contextlib import contextmanager
|
||||
@ -1112,3 +1113,17 @@ def is_multi_device_enable():
|
||||
the number of devices
|
||||
"""
|
||||
return local_mpi_size() > 1
|
||||
|
||||
|
||||
def set_prometheus_multiproc_dir() -> object:
|
||||
# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.10/python/sglang/srt/utils.py#L1266
|
||||
global prometheus_multiproc_dir
|
||||
if "PROMETHEUS_MULTIPROC_DIR" in os.environ:
|
||||
logger.info("User set PROMETHEUS_MULTIPROC_DIR detected.")
|
||||
prometheus_multiproc_dir = tempfile.TemporaryDirectory(
|
||||
dir=os.environ["PROMETHEUS_MULTIPROC_DIR"])
|
||||
else:
|
||||
prometheus_multiproc_dir = tempfile.TemporaryDirectory()
|
||||
os.environ["PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name
|
||||
logger.info(
|
||||
f"PROMETHEUS_MULTIPROC_DIR: {os.environ['PROMETHEUS_MULTIPROC_DIR']}")
|
||||
|
||||
@ -3,7 +3,7 @@ import traceback
|
||||
from collections import deque
|
||||
from dataclasses import dataclass
|
||||
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple,
|
||||
Optional)
|
||||
Optional, Union)
|
||||
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
@ -18,7 +18,7 @@ from .utils import is_llm_response
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .result import (DetokenizedGenerationResultBase, GenerationResult,
|
||||
GenerationResultBase)
|
||||
GenerationResultBase, ResponseWrapper)
|
||||
|
||||
__all__ = [
|
||||
"PostprocWorker",
|
||||
@ -57,7 +57,7 @@ class PostprocWorker:
|
||||
|
||||
@dataclass
|
||||
class Input:
|
||||
rsp: "tllm.Response"
|
||||
rsp: Union["tllm.Response", "ResponseWrapper"]
|
||||
|
||||
# The information necessary for creating a GenerationResult in the first Input for each request
|
||||
sampling_params: Optional[SamplingParams] = None
|
||||
@ -69,6 +69,7 @@ class PostprocWorker:
|
||||
res: Any
|
||||
is_final: bool
|
||||
error: str = ""
|
||||
metrics: Optional[dict[str, float]] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -118,7 +119,9 @@ class PostprocWorker:
|
||||
streaming=inp.streaming,
|
||||
tokenizer=tokenizer)
|
||||
|
||||
async def _handle_input(self, input: "PostprocWorker.Input") -> Any:
|
||||
async def _handle_input(
|
||||
self, input: Union["PostprocWorker.Input", "ResponseWrapper"]
|
||||
) -> [Any, Optional[dict[str, float]]]:
|
||||
''' Handle a single response from await_response worker. '''
|
||||
if input.rsp.result.context_logits is not None or \
|
||||
input.rsp.result.generation_logits is not None:
|
||||
@ -139,6 +142,7 @@ class PostprocWorker:
|
||||
record._handle_response(input.rsp) # inplace
|
||||
# Left the result_handler determine the final output dtype.
|
||||
# NOTE: This will change the CompletionOutput._postprocess_result
|
||||
metrics_dict = record.metrics_dict
|
||||
if postproc_params := record.postproc_params:
|
||||
result_handler, args = postproc_params.post_processor, postproc_params.postproc_args
|
||||
args.tokenizer = self._tokenizer
|
||||
@ -150,7 +154,7 @@ class PostprocWorker:
|
||||
|
||||
# TODO: Keep only the diff token_ids and text in streaming mode when
|
||||
# result_handler is not set
|
||||
return out
|
||||
return out, metrics_dict
|
||||
|
||||
async def _batched_put(self):
|
||||
''' Batched IPC send. '''
|
||||
@ -173,8 +177,12 @@ class PostprocWorker:
|
||||
client_id = inp.rsp.client_id
|
||||
is_final = inp.rsp.result.is_final if is_llm_response(
|
||||
inp.rsp) else True
|
||||
res = await self._handle_input(inp)
|
||||
batch.append(PostprocWorker.Output(client_id, res, is_final))
|
||||
res, metrics = await self._handle_input(inp)
|
||||
batch.append(
|
||||
PostprocWorker.Output(client_id=client_id,
|
||||
res=res,
|
||||
is_final=is_final,
|
||||
metrics=metrics))
|
||||
if is_final:
|
||||
self._records.pop(client_id)
|
||||
|
||||
|
||||
@ -15,6 +15,7 @@ from ..bindings import executor as tllm
|
||||
from ..disaggregated_params import DisaggregatedParams
|
||||
from ..llmapi.tracer import global_tracer
|
||||
from ..llmapi.utils import AsyncQueue
|
||||
from ..metrics import MetricNames, MetricsCollector, RequestEventTiming
|
||||
from ..sampling_params import LogprobParams, SamplingParams
|
||||
from .utils import ErrorResponse, has_event_loop, is_llm_response
|
||||
|
||||
@ -50,14 +51,18 @@ class LogProbsResult(NamedTuple):
|
||||
|
||||
|
||||
class ResponseWrapper:
|
||||
"""Wrapper of runtime response with optional outputs computed post runtime.
|
||||
"""
|
||||
1. Wrapper of runtime response with optional outputs computed post runtime.
|
||||
2. A workaround to pass around RequestPerfMetrics.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
response: Union["PostprocWorker.Output", tllm.Response],
|
||||
logprobs: Optional[LogProbsResult] = None):
|
||||
logprobs: Optional[LogProbsResult] = None,
|
||||
request_perf_metrics: Optional[dict[str, float]] = None):
|
||||
self._response = response
|
||||
self.logprobs = logprobs
|
||||
self.request_perf_metrics = request_perf_metrics
|
||||
|
||||
@property
|
||||
def _is_llm_response(self):
|
||||
@ -68,6 +73,14 @@ class ResponseWrapper:
|
||||
response = object.__getattribute__(self, '_response')
|
||||
return getattr(response, name)
|
||||
|
||||
def __getstate__(self):
|
||||
return (self._response, self.logprobs, self.request_perf_metrics)
|
||||
|
||||
def __setstate__(self, state):
|
||||
self._response = state[0]
|
||||
self.logprobs = state[1]
|
||||
self.request_perf_metrics = state[2]
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class CompletionOutput:
|
||||
@ -146,6 +159,7 @@ class GenerationResultBase:
|
||||
self.disaggregated_params = None
|
||||
self.decoding_iter = 0
|
||||
self._done = False
|
||||
self.metrics_dict = {}
|
||||
|
||||
if has_event_loop():
|
||||
self.aqueue = AsyncQueue()
|
||||
@ -201,7 +215,9 @@ class GenerationResultBase:
|
||||
finish_reasons,
|
||||
response_tensors,
|
||||
sequence_index,
|
||||
logprobs_result=None):
|
||||
logprobs_result=None,
|
||||
req_perf_metrics_dict: Optional[dict[str,
|
||||
float]] = None):
|
||||
""" Handle a single sequence in the response. """
|
||||
|
||||
seq_idx = sequence_index
|
||||
@ -271,6 +287,7 @@ class GenerationResultBase:
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown finish reason: {finish_reasons[src_idx]}")
|
||||
self.record_stats(output, req_perf_metrics_dict)
|
||||
|
||||
@nvtx_range_debug("handle_response",
|
||||
color="red",
|
||||
@ -278,7 +295,9 @@ class GenerationResultBase:
|
||||
def _handle_response(self,
|
||||
response: Union["PostprocWorker.Output", tllm.Response,
|
||||
ResponseWrapper, ErrorResponse]):
|
||||
req_perf_metrics_dict = None
|
||||
if isinstance(response, ResponseWrapper):
|
||||
req_perf_metrics_dict = response.request_perf_metrics
|
||||
logprobs_result = response.logprobs
|
||||
response = response._response
|
||||
else:
|
||||
@ -291,6 +310,8 @@ class GenerationResultBase:
|
||||
self._outputs[0] = response.res
|
||||
else:
|
||||
self._outputs[0]._postprocess_result = response.res
|
||||
if response.metrics:
|
||||
self.metrics_dict = response.metrics
|
||||
|
||||
if response.error:
|
||||
if self._background_error_handler is not None and (
|
||||
@ -303,7 +324,8 @@ class GenerationResultBase:
|
||||
handler(response.error_msg)
|
||||
|
||||
response_result = response.result
|
||||
if hasattr(response_result, "_result"):
|
||||
if hasattr(response_result, "_result") and isinstance(
|
||||
response_result._result, bytes):
|
||||
response_result.deserialize()
|
||||
|
||||
self._done = response_result.is_final
|
||||
@ -322,11 +344,12 @@ class GenerationResultBase:
|
||||
if self.sampling_params.use_beam_search:
|
||||
for beam_idx, _ in enumerate(response_result.output_token_ids):
|
||||
self._handle_sequence(finish_reasons, response_result,
|
||||
beam_idx, logprobs_result)
|
||||
beam_idx, logprobs_result,
|
||||
req_perf_metrics_dict)
|
||||
else:
|
||||
self._handle_sequence(finish_reasons, response_result,
|
||||
response_result.sequence_index,
|
||||
logprobs_result)
|
||||
logprobs_result, req_perf_metrics_dict)
|
||||
|
||||
if response_result.context_logits is not None:
|
||||
self._context_logits = response_result.context_logits
|
||||
@ -342,6 +365,29 @@ class GenerationResultBase:
|
||||
else:
|
||||
raise ValueError(f"Unknown response type: {response}")
|
||||
|
||||
def record_stats(self,
|
||||
output: CompletionOutput,
|
||||
stats: Optional[dict[str, float]] = None) -> None:
|
||||
"""Record the stats of the generation result.
|
||||
|
||||
Args:
|
||||
output (CompletionOutput): The output of the generation result.
|
||||
stats (Optional[dict[str, float]]): The stats of the generation result. Defaults to None.
|
||||
"""
|
||||
if not stats:
|
||||
return
|
||||
metrics_stats = {}
|
||||
if output.finish_reason:
|
||||
metrics_stats.update({
|
||||
MetricsCollector.labelname_finish_reason:
|
||||
output.finish_reason
|
||||
})
|
||||
processed_metrics_stat = _process_req_perf_metrics(
|
||||
stats, len(output.token_ids), self.sampling_params.n > 1)
|
||||
if processed_metrics_stat:
|
||||
metrics_stats.update(processed_metrics_stat)
|
||||
self.metrics_dict = metrics_stats
|
||||
|
||||
|
||||
class DetokenizedGenerationResultBase(GenerationResultBase):
|
||||
''' The base class for the generation result with detokenization support. '''
|
||||
@ -688,3 +734,30 @@ def compute_logprobs(
|
||||
|
||||
return LogProbsResult(prompt=prompt_logprobs,
|
||||
generation=generation_logprobs)
|
||||
|
||||
|
||||
def _process_req_perf_metrics(
|
||||
req_perf_metrics_dict: Optional[dict[str, float]],
|
||||
output_length: int,
|
||||
is_multiple_response: bool = False) -> dict[MetricNames, float]:
|
||||
stat = {}
|
||||
if not req_perf_metrics_dict:
|
||||
return stat
|
||||
ttft = req_perf_metrics_dict.get(RequestEventTiming.FIRST_TOKEN_TIME, 0) - \
|
||||
req_perf_metrics_dict.get(RequestEventTiming.ARRIVAL_TIME, 0)
|
||||
e2e = req_perf_metrics_dict.get(RequestEventTiming.LAST_TOKEN_TIME, 0) - \
|
||||
req_perf_metrics_dict.get(RequestEventTiming.ARRIVAL_TIME, 0)
|
||||
request_queue_time = req_perf_metrics_dict.get(RequestEventTiming.FIRST_SCHEDULED_TIME, 0) - \
|
||||
req_perf_metrics_dict.get(RequestEventTiming.ARRIVAL_TIME, 0)
|
||||
stat = {
|
||||
MetricNames.TTFT: ttft,
|
||||
MetricNames.E2E: e2e,
|
||||
MetricNames.REQUEST_QUEUE_TIME: request_queue_time
|
||||
}
|
||||
if output_length > 1 and not is_multiple_response:
|
||||
tpot = (req_perf_metrics_dict.get(
|
||||
RequestEventTiming.LAST_TOKEN_TIME, 0) - req_perf_metrics_dict.get(
|
||||
RequestEventTiming.FIRST_TOKEN_TIME, 0)) / (output_length - 1)
|
||||
stat.update({MetricNames.TPOT: tpot})
|
||||
stat = dict(filter(lambda item: item[1] > 0, stat.items()))
|
||||
return stat
|
||||
|
||||
@ -25,6 +25,7 @@ from ..llmapi.utils import (AsyncQueue, ManagedThread, _SyncQueue,
|
||||
clear_sched_affinity, print_colored_debug,
|
||||
print_traceback_on_error)
|
||||
from ..lora_manager import LoraConfig, LoraManager
|
||||
from ..metrics import RequestEventTiming
|
||||
from ..prompt_adapter_manager import PromptAdapterManager
|
||||
from ..runtime import ModelConfig
|
||||
from ..runtime.model_runner import _engine_config_to_model_config
|
||||
@ -899,10 +900,8 @@ class AwaitResponseHelper:
|
||||
assert response is not None
|
||||
queue = self.worker.return_queue(response.client_id)
|
||||
|
||||
logprobs_result = _get_logprobs(self.worker, response,
|
||||
response = _maybe_wrap_response(self.worker, response,
|
||||
self.worker._is_pytorch_backend)
|
||||
if logprobs_result:
|
||||
response = ResponseWrapper(response, logprobs_result)
|
||||
|
||||
# For AsyncQueue.sync_q, we will batch the events to avoid too many
|
||||
# event notifications, thus put without wait here.
|
||||
@ -940,10 +939,8 @@ class AwaitResponseHelper:
|
||||
response = ErrorResponse(response.client_id, response.error_msg,
|
||||
response.request_id)
|
||||
else:
|
||||
logprobs_result = _get_logprobs(self.worker, response,
|
||||
response = _maybe_wrap_response(self.worker, response,
|
||||
self.worker._is_pytorch_backend)
|
||||
if logprobs_result:
|
||||
response = ResponseWrapper(response, logprobs_result)
|
||||
|
||||
_send_rsp(self.worker,
|
||||
response,
|
||||
@ -1051,3 +1048,41 @@ def _send_rsp(
|
||||
worker._pop_result(response.client_id)
|
||||
else:
|
||||
raise ValueError(f"Unknown response type: {response}")
|
||||
|
||||
|
||||
def _get_metrics_dict(
|
||||
response: tllm.Response) -> dict[RequestEventTiming, float]:
|
||||
req_perf_metrics, metrics_dict = None, {}
|
||||
res = response.result
|
||||
if res:
|
||||
if hasattr(res, '_result'):
|
||||
if result := res.get_result():
|
||||
req_perf_metrics = result.request_perf_metrics
|
||||
else:
|
||||
req_perf_metrics = res.request_perf_metrics
|
||||
if req_perf_metrics and req_perf_metrics.timing_metrics:
|
||||
metrics_dict = {
|
||||
RequestEventTiming.ARRIVAL_TIME:
|
||||
req_perf_metrics.timing_metrics.arrival_time.total_seconds(),
|
||||
RequestEventTiming.FIRST_TOKEN_TIME:
|
||||
req_perf_metrics.timing_metrics.first_token_time.total_seconds(
|
||||
),
|
||||
RequestEventTiming.FIRST_SCHEDULED_TIME:
|
||||
req_perf_metrics.timing_metrics.first_scheduled_time.
|
||||
total_seconds(),
|
||||
RequestEventTiming.LAST_TOKEN_TIME:
|
||||
req_perf_metrics.timing_metrics.last_token_time.total_seconds()
|
||||
}
|
||||
return metrics_dict
|
||||
|
||||
|
||||
def _maybe_wrap_response(
|
||||
worker,
|
||||
response: tllm.Response,
|
||||
is_pytorch_backend=False) -> Union[tllm.Response, ResponseWrapper]:
|
||||
|
||||
logprobs_result = _get_logprobs(worker, response, is_pytorch_backend)
|
||||
req_perf_metrics = _get_metrics_dict(response)
|
||||
if logprobs_result or req_perf_metrics:
|
||||
response = ResponseWrapper(response, logprobs_result, req_perf_metrics)
|
||||
return response
|
||||
|
||||
@ -548,7 +548,7 @@ class BaseLLM:
|
||||
if sampling_params._stream_interval is None:
|
||||
sampling_params._stream_interval = getattr(self.args,
|
||||
"stream_interval", 1)
|
||||
|
||||
sampling_params.return_perf_metrics = sampling_params.return_perf_metrics or self.args.return_perf_metrics
|
||||
return sampling_params
|
||||
|
||||
def _check_arguments(self, prompt_len: int, query_len: int,
|
||||
|
||||
@ -1311,6 +1311,10 @@ class BaseLlmArgs(StrictBaseModel):
|
||||
status="deprecated",
|
||||
)
|
||||
|
||||
return_perf_metrics: bool = Field(default=False,
|
||||
description="Return perf metrics.",
|
||||
status="prototype")
|
||||
|
||||
_parallel_config: Optional[object] = PrivateAttr(default=None)
|
||||
_model_format: Optional[_ModelFormatKind] = PrivateAttr(default=None)
|
||||
_speculative_model: Optional[str] = PrivateAttr(default=None)
|
||||
|
||||
4
tensorrt_llm/metrics/__init__.py
Normal file
4
tensorrt_llm/metrics/__init__.py
Normal file
@ -0,0 +1,4 @@
|
||||
from .collector import *
|
||||
from .enums import *
|
||||
|
||||
__all__ = ["MetricsCollector", "MetricNames", "RequestEventTiming"]
|
||||
105
tensorrt_llm/metrics/collector.py
Normal file
105
tensorrt_llm/metrics/collector.py
Normal file
@ -0,0 +1,105 @@
|
||||
"""Utilities for Prometheus Metrics Collection."""
|
||||
|
||||
import time
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
from .enums import MetricNames
|
||||
|
||||
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0rc1/vllm/engine/metrics.py#L30
|
||||
class MetricsCollector:
|
||||
labelname_finish_reason = "finished_reason"
|
||||
|
||||
def __init__(self, labels: Dict[str, str]) -> None:
|
||||
from prometheus_client import Counter, Histogram
|
||||
self.last_log_time = time.time()
|
||||
self.labels = labels
|
||||
|
||||
self.finish_reason_label = {
|
||||
MetricsCollector.labelname_finish_reason: "unknown"
|
||||
}
|
||||
self.labels_with_finished_reason = {
|
||||
**self.labels,
|
||||
**self.finish_reason_label
|
||||
}
|
||||
|
||||
self.counter_request_success = Counter(
|
||||
name="request_success_total",
|
||||
documentation="Count of successfully processed requests.",
|
||||
labelnames=self.labels_with_finished_reason.keys())
|
||||
|
||||
self.histogram_e2e_time_request = Histogram(
|
||||
name="e2e_request_latency_seconds",
|
||||
documentation="Histogram of end to end request latency in seconds.",
|
||||
buckets=[
|
||||
0.3, 0.5, 0.8, 1.0, 1.5, 2.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0,
|
||||
40.0, 50.0, 60.0, 120.0, 240.0, 480.0, 960.0, 1920.0, 7680.0
|
||||
],
|
||||
labelnames=self.labels.keys())
|
||||
|
||||
self.histogram_time_to_first_token = Histogram(
|
||||
name="time_to_first_token_seconds",
|
||||
documentation="Histogram of time to first token in seconds.",
|
||||
buckets=[
|
||||
0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5,
|
||||
0.75, 1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0, 160.0, 640.0,
|
||||
2560.0
|
||||
],
|
||||
labelnames=self.labels.keys())
|
||||
|
||||
self.histogram_time_per_output_token = Histogram(
|
||||
name="time_per_output_token_seconds",
|
||||
documentation="Histogram of time per output token in seconds.",
|
||||
buckets=[
|
||||
0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75,
|
||||
1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0
|
||||
],
|
||||
labelnames=self.labels.keys())
|
||||
|
||||
self.histogram_queue_time_request = Histogram(
|
||||
name="request_queue_time_seconds",
|
||||
documentation=
|
||||
"Histogram of time spent in WAITING phase for request.",
|
||||
buckets=[
|
||||
0.3, 0.5, 0.8, 1.0, 1.5, 2.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0,
|
||||
40.0, 50.0, 60.0, 120.0, 240.0, 480.0, 960.0, 1920.0, 7680.0
|
||||
],
|
||||
labelnames=self.labels.keys())
|
||||
|
||||
def _label_merge(self, labels: Dict[str, str]) -> Dict[str, str]:
|
||||
if labels is None or len(labels) == 0:
|
||||
return self.labels
|
||||
return {**self.labels, **labels}
|
||||
|
||||
def _log_counter(self, counter, labels: Dict[str, str],
|
||||
data: Union[int, float]) -> None:
|
||||
# Convenience function for logging to counter.
|
||||
counter.labels(**self._label_merge(labels)).inc(data)
|
||||
|
||||
def _log_histogram(self, histogram, data: Union[int, float]) -> None:
|
||||
# Convenience function for logging to histogram.
|
||||
histogram.labels(**self.labels).observe(data)
|
||||
|
||||
def log_request_success(self, data: Union[int, float],
|
||||
labels: Dict[str, str]) -> None:
|
||||
self._log_counter(self.counter_request_success, labels, data)
|
||||
self.last_log_time = time.time()
|
||||
|
||||
def log_histogram(self, data: Optional[dict[str, float]]) -> None:
|
||||
if e2e := data.get(MetricNames.E2E, 0):
|
||||
self._log_histogram(self.histogram_e2e_time_request, e2e)
|
||||
if ttft := data.get(MetricNames.TTFT, 0):
|
||||
self._log_histogram(self.histogram_time_to_first_token, ttft)
|
||||
if tpot := data.get(MetricNames.TPOT, 0):
|
||||
self._log_histogram(self.histogram_time_per_output_token, tpot)
|
||||
if request_queue_time := data.get(MetricNames.REQUEST_QUEUE_TIME, 0):
|
||||
self._log_histogram(self.histogram_queue_time_request,
|
||||
request_queue_time)
|
||||
self.last_log_time = time.time()
|
||||
|
||||
def log_metrics_dict(self, metrics_dict: dict[str, float]) -> None:
|
||||
if finish_reason := metrics_dict.get(
|
||||
MetricsCollector.labelname_finish_reason):
|
||||
self.log_request_success(
|
||||
1, {MetricsCollector.labelname_finish_reason: finish_reason})
|
||||
self.log_histogram(metrics_dict)
|
||||
15
tensorrt_llm/metrics/enums.py
Normal file
15
tensorrt_llm/metrics/enums.py
Normal file
@ -0,0 +1,15 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class MetricNames(Enum):
|
||||
TTFT = "ttft"
|
||||
TPOT = "tpot"
|
||||
E2E = "e2e"
|
||||
REQUEST_QUEUE_TIME = "request_queue_time"
|
||||
|
||||
|
||||
class RequestEventTiming(Enum):
|
||||
ARRIVAL_TIME = "arrival_time"
|
||||
FIRST_TOKEN_TIME = "first_token_time" # nosec: B105
|
||||
FIRST_SCHEDULED_TIME = "first_scheduled_time"
|
||||
LAST_TOKEN_TIME = "last_token_time" # nosec: B105
|
||||
@ -1,6 +1,7 @@
|
||||
#!/usr/bin/env python
|
||||
import asyncio
|
||||
import os
|
||||
import re
|
||||
import signal
|
||||
import traceback
|
||||
from contextlib import asynccontextmanager
|
||||
@ -13,6 +14,7 @@ import uvicorn
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse, Response, StreamingResponse
|
||||
from starlette.routing import Mount
|
||||
from transformers import AutoConfig, AutoProcessor
|
||||
|
||||
from tensorrt_llm._tensorrt_engine import LLM
|
||||
@ -25,6 +27,7 @@ from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams
|
||||
from tensorrt_llm.llmapi.disagg_utils import MetadataServerConfig, ServerRole
|
||||
from tensorrt_llm.llmapi.llm import RequestOutput
|
||||
from tensorrt_llm.logger import logger
|
||||
from tensorrt_llm.metrics.collector import MetricsCollector
|
||||
from tensorrt_llm.serve.chat_utils import (check_multiple_response,
|
||||
parse_chat_messages_coroutines)
|
||||
from tensorrt_llm.serve.metadata_server import create_metadata_server
|
||||
@ -42,7 +45,7 @@ from tensorrt_llm.serve.postprocess_handlers import (
|
||||
completion_stream_post_processor)
|
||||
from tensorrt_llm.version import __version__ as VERSION
|
||||
|
||||
from .._utils import nvtx_mark
|
||||
from .._utils import nvtx_mark, set_prometheus_multiproc_dir
|
||||
|
||||
# yapf: enale
|
||||
TIMEOUT_KEEP_ALIVE = 5 # seconds.
|
||||
@ -78,6 +81,13 @@ class OpenAIServer:
|
||||
self.model = model_dir.name
|
||||
else:
|
||||
self.model = model
|
||||
self.metrics_collector = None
|
||||
if self.llm.args.return_perf_metrics:
|
||||
set_prometheus_multiproc_dir()
|
||||
self.metrics_collector = MetricsCollector({
|
||||
"model_name": "undefined",
|
||||
"engine_type": "undefined"
|
||||
})
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
@ -151,6 +161,32 @@ class OpenAIServer:
|
||||
self.app.add_api_route("/v1/chat/completions",
|
||||
self.openai_chat,
|
||||
methods=["POST"])
|
||||
if self.llm.args.return_perf_metrics:
|
||||
# register /prometheus/metrics
|
||||
self.mount_metrics()
|
||||
|
||||
def mount_metrics(self):
|
||||
# Lazy import for prometheus multiprocessing.
|
||||
# We need to set PROMETHEUS_MULTIPROC_DIR environment variable
|
||||
# before prometheus_client is imported.
|
||||
# See https://prometheus.github.io/client_python/multiprocess/
|
||||
from prometheus_client import (CollectorRegistry, make_asgi_app,
|
||||
multiprocess)
|
||||
from prometheus_fastapi_instrumentator import Instrumentator
|
||||
registry = CollectorRegistry()
|
||||
multiprocess.MultiProcessCollector(registry)
|
||||
Instrumentator(
|
||||
should_group_status_codes=False,
|
||||
should_respect_env_var=True,
|
||||
excluded_handlers=[
|
||||
".*"
|
||||
],
|
||||
registry=registry,
|
||||
).add().instrument(self.app).expose(self.app)
|
||||
metrics_app = make_asgi_app(registry=registry)
|
||||
metrics_route = Mount("/prometheus/metrics", metrics_app)
|
||||
metrics_route.path_regex = re.compile("^/prometheus/metrics(?P<path>.*)$")
|
||||
self.app.routes.append(metrics_route)
|
||||
|
||||
async def health(self) -> Response:
|
||||
return Response(status_code=200)
|
||||
@ -228,6 +264,8 @@ class OpenAIServer:
|
||||
post_processor, args = postproc_params.post_processor, postproc_params.postproc_args
|
||||
async for res in promise:
|
||||
pp_results = res.outputs[0]._postprocess_result if self.postproc_worker_enabled else post_processor(res, args)
|
||||
if res.finished and self.metrics_collector:
|
||||
self.metrics_collector.log_metrics_dict(res.metrics_dict)
|
||||
for pp_res in pp_results:
|
||||
yield pp_res
|
||||
yield "data: [DONE]\n\n"
|
||||
@ -245,6 +283,8 @@ class OpenAIServer:
|
||||
# Add prompt_tokens_ids to the response
|
||||
if disaggregated_params and disaggregated_params.request_type and disaggregated_params.request_type == "context_only":
|
||||
chat_response.prompt_token_ids = promise.prompt_token_ids
|
||||
if promise.finished and self.metrics_collector:
|
||||
self.metrics_collector.log_metrics_dict(promise.metrics_dict)
|
||||
return chat_response
|
||||
|
||||
try:
|
||||
@ -337,6 +377,8 @@ class OpenAIServer:
|
||||
if disaggregated_params and disaggregated_params.request_type and disaggregated_params.request_type == "context_only":
|
||||
# Include prompt token ids for context-only requests
|
||||
pp_result.prompt_token_ids = response.prompt_token_ids
|
||||
if response.finished and self.metrics_collector:
|
||||
self.metrics_collector.log_metrics_dict(response.metrics_dict)
|
||||
return pp_result
|
||||
|
||||
def merge_completion_responses(responses: List[CompletionResponse]) -> CompletionResponse:
|
||||
@ -372,6 +414,8 @@ class OpenAIServer:
|
||||
pp_result = post_processor(output, args)
|
||||
else:
|
||||
pp_result = output.outputs[0]._postprocess_result
|
||||
if output.finished and self.metrics_collector:
|
||||
self.metrics_collector.log_metrics_dict(output.metrics_dict)
|
||||
for pp_res in pp_result:
|
||||
yield pp_res
|
||||
|
||||
|
||||
@ -1497,6 +1497,13 @@ def test_openai_chat_with_logit_bias(llm_root, llm_venv, sampler: str):
|
||||
])
|
||||
|
||||
|
||||
def test_openai_prometheus(llm_root, llm_venv):
|
||||
test_root = unittest_path() / "llmapi" / "apps"
|
||||
llm_venv.run_cmd(
|
||||
["-m", "pytest",
|
||||
str(test_root / "_test_openai_prometheus.py")])
|
||||
|
||||
|
||||
def test_openai_lora(llm_root, llm_venv):
|
||||
test_root = unittest_path() / "llmapi" / "apps"
|
||||
llm_venv.run_cmd(["-m", "pytest", str(test_root / "_test_openai_lora.py")])
|
||||
|
||||
@ -25,6 +25,7 @@ l0_a10:
|
||||
- test_e2e.py::test_openai_chat_structural_tag_example
|
||||
- test_e2e.py::test_openai_chat_json_example
|
||||
- test_e2e.py::test_openai_chat_multimodal_example
|
||||
- test_e2e.py::test_openai_prometheus
|
||||
- test_e2e.py::test_openai_lora
|
||||
- test_e2e.py::test_trtllm_serve_multimodal_example
|
||||
- test_e2e.py::test_trtllm_serve_lora_example
|
||||
|
||||
@ -27,6 +27,10 @@ methods:
|
||||
annotation: Optional[int]
|
||||
default: null
|
||||
status: prototype
|
||||
return_perf_metrics:
|
||||
annotation: bool
|
||||
default: False
|
||||
status: prototype
|
||||
# Bindings and mirrored configs
|
||||
peft_cache_config:
|
||||
annotation: Optional[tensorrt_llm.llmapi.llm_args.PeftCacheConfig]
|
||||
|
||||
@ -11,4 +11,13 @@ methods:
|
||||
clear_logprob_params:
|
||||
parameters: {}
|
||||
return_annotation: None
|
||||
record_stats:
|
||||
parameters:
|
||||
output:
|
||||
annotation: tensorrt_llm.executor.result.CompletionOutput
|
||||
default: inspect._empty
|
||||
stats:
|
||||
annotation: Optional[dict[str, float]]
|
||||
default: None
|
||||
return_annotation: None
|
||||
properties: {}
|
||||
|
||||
67
tests/unittest/llmapi/apps/_test_openai_prometheus.py
Normal file
67
tests/unittest/llmapi/apps/_test_openai_prometheus.py
Normal file
@ -0,0 +1,67 @@
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from urllib.request import urlopen
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from ..test_llm import get_model_path
|
||||
from .openai_server import RemoteOpenAIServer
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", ids=["TinyLlama-1.1B-Chat"])
|
||||
def model_name():
|
||||
return "llama-models-v2/TinyLlama-1.1B-Chat-v1.0"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def temp_extra_llm_api_options_file(request):
|
||||
temp_dir = tempfile.gettempdir()
|
||||
temp_file_path = os.path.join(temp_dir, "extra_llm_api_options.yaml")
|
||||
try:
|
||||
extra_llm_api_options_dict = {"return_perf_metrics": True}
|
||||
|
||||
with open(temp_file_path, 'w') as f:
|
||||
yaml.dump(extra_llm_api_options_dict, f)
|
||||
|
||||
yield temp_file_path
|
||||
finally:
|
||||
if os.path.exists(temp_file_path):
|
||||
os.remove(temp_file_path)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server(model_name: str,
|
||||
temp_extra_llm_api_options_file: str) -> RemoteOpenAIServer:
|
||||
model_path = get_model_path(model_name)
|
||||
args = ["--backend", "pytorch", "--tp_size", "1"]
|
||||
args.extend(["--extra_llm_api_options", temp_extra_llm_api_options_file])
|
||||
logger.info(f"Starting server, model: {model_name}, args: {args}")
|
||||
with RemoteOpenAIServer(model_path, args) as remote_server:
|
||||
yield remote_server
|
||||
logger.info("Tests completed, shutting down server")
|
||||
|
||||
|
||||
def test_metrics_endpoint(server: RemoteOpenAIServer):
|
||||
|
||||
client = server.get_client()
|
||||
client.completions.create(
|
||||
model="Server",
|
||||
prompt="Hello, my name is",
|
||||
max_tokens=25,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
response = urlopen(f'{server.url_root}/prometheus/metrics')
|
||||
assert response.status is 200
|
||||
|
||||
data = response.read().decode("utf-8")
|
||||
assert "request_success_total" in data
|
||||
assert "e2e_request_latency_seconds" in data
|
||||
assert "time_to_first_token_seconds" in data
|
||||
assert "request_queue_time_seconds" in data
|
||||
@ -6,6 +6,7 @@ from tensorrt_llm import LLM
|
||||
from tensorrt_llm.llmapi import KvCacheConfig
|
||||
from tensorrt_llm.llmapi.llm_args import PeftCacheConfig
|
||||
from tensorrt_llm.llmapi.tokenizer import TransformersTokenizer
|
||||
from tensorrt_llm.metrics import MetricNames
|
||||
from tensorrt_llm.sampling_params import SamplingParams
|
||||
|
||||
# isort: off
|
||||
@ -195,6 +196,27 @@ def test_llm_perf_metrics():
|
||||
assert perf_metrics.last_iter == perf_metrics.iter
|
||||
|
||||
|
||||
def test_llm_prometheus():
|
||||
test_prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
sampling_params = SamplingParams(max_tokens=10, temperature=0.8, top_p=0.95)
|
||||
llm = LLM(model=llama_model_path,
|
||||
return_perf_metrics=True,
|
||||
kv_cache_config=global_kvcache_config)
|
||||
for test_prompt in test_prompts:
|
||||
request_output = llm.generate(test_prompt, sampling_params)
|
||||
assert request_output.metrics_dict is not None
|
||||
assert MetricNames.REQUEST_QUEUE_TIME in request_output.metrics_dict
|
||||
assert MetricNames.TPOT in request_output.metrics_dict
|
||||
assert MetricNames.TTFT in request_output.metrics_dict
|
||||
assert MetricNames.E2E in request_output.metrics_dict
|
||||
assert request_output.outputs is not None
|
||||
|
||||
|
||||
@pytest.mark.parametrize("streaming", [True, False])
|
||||
def test_llm_with_postprocess_parallel_and_result_handler(streaming):
|
||||
run_llm_with_postprocess_parallel_and_result_handler(streaming,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user