[None][feat] Core Metrics Implementation (#5785)

Signed-off-by: Ye Zhang <zhysishu@gmail.com>
Signed-off-by: Shunkang <182541032+Shunkangz@users.noreply.github.com>
This commit is contained in:
Ye Zhang 2025-08-09 14:48:53 +08:00 committed by GitHub
parent 97787883c3
commit bcf5ec0c9a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 442 additions and 21 deletions

View File

@ -29,6 +29,8 @@ nvidia-modelopt[torch]~=0.33.0
nvidia-nccl-cu12 nvidia-nccl-cu12
nvidia-cuda-nvrtc-cu12 nvidia-cuda-nvrtc-cu12
transformers==4.55.0 transformers==4.55.0
prometheus_client
prometheus_fastapi_instrumentator
pydantic>=2.9.1 pydantic>=2.9.1
pydantic-settings[yaml] pydantic-settings[yaml]
omegaconf omegaconf

View File

@ -250,6 +250,12 @@ class LlmResult:
self._result = tensorrt_llm.bindings.executor.deserialize_result( self._result = tensorrt_llm.bindings.executor.deserialize_result(
self._result) self._result)
def get_result(self):
if tmp_res := tensorrt_llm.bindings.executor.deserialize_result(
self._result):
return tmp_res
return None
@dataclass @dataclass
class LlmResponse: class LlmResponse:

View File

@ -20,6 +20,7 @@ import linecache
import math import math
import os import os
import struct import struct
import tempfile
import trace import trace
import weakref import weakref
from contextlib import contextmanager from contextlib import contextmanager
@ -1112,3 +1113,17 @@ def is_multi_device_enable():
the number of devices the number of devices
""" """
return local_mpi_size() > 1 return local_mpi_size() > 1
def set_prometheus_multiproc_dir() -> object:
# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.10/python/sglang/srt/utils.py#L1266
global prometheus_multiproc_dir
if "PROMETHEUS_MULTIPROC_DIR" in os.environ:
logger.info("User set PROMETHEUS_MULTIPROC_DIR detected.")
prometheus_multiproc_dir = tempfile.TemporaryDirectory(
dir=os.environ["PROMETHEUS_MULTIPROC_DIR"])
else:
prometheus_multiproc_dir = tempfile.TemporaryDirectory()
os.environ["PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name
logger.info(
f"PROMETHEUS_MULTIPROC_DIR: {os.environ['PROMETHEUS_MULTIPROC_DIR']}")

View File

@ -3,7 +3,7 @@ import traceback
from collections import deque from collections import deque
from dataclasses import dataclass from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple, from typing import (TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple,
Optional) Optional, Union)
import zmq import zmq
import zmq.asyncio import zmq.asyncio
@ -18,7 +18,7 @@ from .utils import is_llm_response
if TYPE_CHECKING: if TYPE_CHECKING:
from .result import (DetokenizedGenerationResultBase, GenerationResult, from .result import (DetokenizedGenerationResultBase, GenerationResult,
GenerationResultBase) GenerationResultBase, ResponseWrapper)
__all__ = [ __all__ = [
"PostprocWorker", "PostprocWorker",
@ -57,7 +57,7 @@ class PostprocWorker:
@dataclass @dataclass
class Input: class Input:
rsp: "tllm.Response" rsp: Union["tllm.Response", "ResponseWrapper"]
# The information necessary for creating a GenerationResult in the first Input for each request # The information necessary for creating a GenerationResult in the first Input for each request
sampling_params: Optional[SamplingParams] = None sampling_params: Optional[SamplingParams] = None
@ -69,6 +69,7 @@ class PostprocWorker:
res: Any res: Any
is_final: bool is_final: bool
error: str = "" error: str = ""
metrics: Optional[dict[str, float]] = None
def __init__( def __init__(
self, self,
@ -118,7 +119,9 @@ class PostprocWorker:
streaming=inp.streaming, streaming=inp.streaming,
tokenizer=tokenizer) tokenizer=tokenizer)
async def _handle_input(self, input: "PostprocWorker.Input") -> Any: async def _handle_input(
self, input: Union["PostprocWorker.Input", "ResponseWrapper"]
) -> tuple[Any, Optional[dict[str, float]]]:
''' Handle a single response from await_response worker. ''' ''' Handle a single response from await_response worker. '''
if input.rsp.result.context_logits is not None or \ if input.rsp.result.context_logits is not None or \
input.rsp.result.generation_logits is not None: input.rsp.result.generation_logits is not None:
@ -139,6 +142,7 @@ class PostprocWorker:
record._handle_response(input.rsp) # inplace record._handle_response(input.rsp) # inplace
# Left the result_handler determine the final output dtype. # Left the result_handler determine the final output dtype.
# NOTE: This will change the CompletionOutput._postprocess_result # NOTE: This will change the CompletionOutput._postprocess_result
metrics_dict = record.metrics_dict
if postproc_params := record.postproc_params: if postproc_params := record.postproc_params:
result_handler, args = postproc_params.post_processor, postproc_params.postproc_args result_handler, args = postproc_params.post_processor, postproc_params.postproc_args
args.tokenizer = self._tokenizer args.tokenizer = self._tokenizer
@ -150,7 +154,7 @@ class PostprocWorker:
# TODO: Keep only the diff token_ids and text in streaming mode when # TODO: Keep only the diff token_ids and text in streaming mode when
# result_handler is not set # result_handler is not set
return out return out, metrics_dict
async def _batched_put(self): async def _batched_put(self):
''' Batched IPC send. ''' ''' Batched IPC send. '''
@ -173,8 +177,12 @@ class PostprocWorker:
client_id = inp.rsp.client_id client_id = inp.rsp.client_id
is_final = inp.rsp.result.is_final if is_llm_response( is_final = inp.rsp.result.is_final if is_llm_response(
inp.rsp) else True inp.rsp) else True
res = await self._handle_input(inp) res, metrics = await self._handle_input(inp)
batch.append(PostprocWorker.Output(client_id, res, is_final)) batch.append(
PostprocWorker.Output(client_id=client_id,
res=res,
is_final=is_final,
metrics=metrics))
if is_final: if is_final:
self._records.pop(client_id) self._records.pop(client_id)

View File

@ -15,6 +15,7 @@ from ..bindings import executor as tllm
from ..disaggregated_params import DisaggregatedParams from ..disaggregated_params import DisaggregatedParams
from ..llmapi.tracer import global_tracer from ..llmapi.tracer import global_tracer
from ..llmapi.utils import AsyncQueue from ..llmapi.utils import AsyncQueue
from ..metrics import MetricNames, MetricsCollector, RequestEventTiming
from ..sampling_params import LogprobParams, SamplingParams from ..sampling_params import LogprobParams, SamplingParams
from .utils import ErrorResponse, has_event_loop, is_llm_response from .utils import ErrorResponse, has_event_loop, is_llm_response
@ -50,14 +51,18 @@ class LogProbsResult(NamedTuple):
class ResponseWrapper: class ResponseWrapper:
"""Wrapper of runtime response with optional outputs computed post runtime. """
1. Wrapper of runtime response with optional outputs computed post runtime.
2. A workaround to pass around RequestPerfMetrics.
""" """
def __init__(self, def __init__(self,
response: Union["PostprocWorker.Output", tllm.Response], response: Union["PostprocWorker.Output", tllm.Response],
logprobs: Optional[LogProbsResult] = None): logprobs: Optional[LogProbsResult] = None,
request_perf_metrics: Optional[dict[str, float]] = None):
self._response = response self._response = response
self.logprobs = logprobs self.logprobs = logprobs
self.request_perf_metrics = request_perf_metrics
@property @property
def _is_llm_response(self): def _is_llm_response(self):
@ -68,6 +73,14 @@ class ResponseWrapper:
response = object.__getattribute__(self, '_response') response = object.__getattribute__(self, '_response')
return getattr(response, name) return getattr(response, name)
def __getstate__(self):
return (self._response, self.logprobs, self.request_perf_metrics)
def __setstate__(self, state):
self._response = state[0]
self.logprobs = state[1]
self.request_perf_metrics = state[2]
@dataclass(slots=True) @dataclass(slots=True)
class CompletionOutput: class CompletionOutput:
@ -146,6 +159,7 @@ class GenerationResultBase:
self.disaggregated_params = None self.disaggregated_params = None
self.decoding_iter = 0 self.decoding_iter = 0
self._done = False self._done = False
self.metrics_dict = {}
if has_event_loop(): if has_event_loop():
self.aqueue = AsyncQueue() self.aqueue = AsyncQueue()
@ -201,7 +215,9 @@ class GenerationResultBase:
finish_reasons, finish_reasons,
response_tensors, response_tensors,
sequence_index, sequence_index,
logprobs_result=None): logprobs_result=None,
req_perf_metrics_dict: Optional[dict[str,
float]] = None):
""" Handle a single sequence in the response. """ """ Handle a single sequence in the response. """
seq_idx = sequence_index seq_idx = sequence_index
@ -271,6 +287,7 @@ class GenerationResultBase:
else: else:
raise ValueError( raise ValueError(
f"Unknown finish reason: {finish_reasons[src_idx]}") f"Unknown finish reason: {finish_reasons[src_idx]}")
self.record_stats(output, req_perf_metrics_dict)
@nvtx_range_debug("handle_response", @nvtx_range_debug("handle_response",
color="red", color="red",
@ -278,7 +295,9 @@ class GenerationResultBase:
def _handle_response(self, def _handle_response(self,
response: Union["PostprocWorker.Output", tllm.Response, response: Union["PostprocWorker.Output", tllm.Response,
ResponseWrapper, ErrorResponse]): ResponseWrapper, ErrorResponse]):
req_perf_metrics_dict = None
if isinstance(response, ResponseWrapper): if isinstance(response, ResponseWrapper):
req_perf_metrics_dict = response.request_perf_metrics
logprobs_result = response.logprobs logprobs_result = response.logprobs
response = response._response response = response._response
else: else:
@ -291,6 +310,8 @@ class GenerationResultBase:
self._outputs[0] = response.res self._outputs[0] = response.res
else: else:
self._outputs[0]._postprocess_result = response.res self._outputs[0]._postprocess_result = response.res
if response.metrics:
self.metrics_dict = response.metrics
if response.error: if response.error:
if self._background_error_handler is not None and ( if self._background_error_handler is not None and (
@ -303,7 +324,8 @@ class GenerationResultBase:
handler(response.error_msg) handler(response.error_msg)
response_result = response.result response_result = response.result
if hasattr(response_result, "_result"): if hasattr(response_result, "_result") and isinstance(
response_result._result, bytes):
response_result.deserialize() response_result.deserialize()
self._done = response_result.is_final self._done = response_result.is_final
@ -322,11 +344,12 @@ class GenerationResultBase:
if self.sampling_params.use_beam_search: if self.sampling_params.use_beam_search:
for beam_idx, _ in enumerate(response_result.output_token_ids): for beam_idx, _ in enumerate(response_result.output_token_ids):
self._handle_sequence(finish_reasons, response_result, self._handle_sequence(finish_reasons, response_result,
beam_idx, logprobs_result) beam_idx, logprobs_result,
req_perf_metrics_dict)
else: else:
self._handle_sequence(finish_reasons, response_result, self._handle_sequence(finish_reasons, response_result,
response_result.sequence_index, response_result.sequence_index,
logprobs_result) logprobs_result, req_perf_metrics_dict)
if response_result.context_logits is not None: if response_result.context_logits is not None:
self._context_logits = response_result.context_logits self._context_logits = response_result.context_logits
@ -342,6 +365,29 @@ class GenerationResultBase:
else: else:
raise ValueError(f"Unknown response type: {response}") raise ValueError(f"Unknown response type: {response}")
def record_stats(self,
output: CompletionOutput,
stats: Optional[dict[str, float]] = None) -> None:
"""Record the stats of the generation result.
Args:
output (CompletionOutput): The output of the generation result.
stats (Optional[dict[str, float]]): The stats of the generation result. Defaults to None.
"""
if not stats:
return
metrics_stats = {}
if output.finish_reason:
metrics_stats.update({
MetricsCollector.labelname_finish_reason:
output.finish_reason
})
processed_metrics_stat = _process_req_perf_metrics(
stats, len(output.token_ids), self.sampling_params.n > 1)
if processed_metrics_stat:
metrics_stats.update(processed_metrics_stat)
self.metrics_dict = metrics_stats
class DetokenizedGenerationResultBase(GenerationResultBase): class DetokenizedGenerationResultBase(GenerationResultBase):
''' The base class for the generation result with detokenization support. ''' ''' The base class for the generation result with detokenization support. '''
@ -688,3 +734,30 @@ def compute_logprobs(
return LogProbsResult(prompt=prompt_logprobs, return LogProbsResult(prompt=prompt_logprobs,
generation=generation_logprobs) generation=generation_logprobs)
def _process_req_perf_metrics(
req_perf_metrics_dict: Optional[dict[str, float]],
output_length: int,
is_multiple_response: bool = False) -> dict[MetricNames, float]:
stat = {}
if not req_perf_metrics_dict:
return stat
ttft = req_perf_metrics_dict.get(RequestEventTiming.FIRST_TOKEN_TIME, 0) - \
req_perf_metrics_dict.get(RequestEventTiming.ARRIVAL_TIME, 0)
e2e = req_perf_metrics_dict.get(RequestEventTiming.LAST_TOKEN_TIME, 0) - \
req_perf_metrics_dict.get(RequestEventTiming.ARRIVAL_TIME, 0)
request_queue_time = req_perf_metrics_dict.get(RequestEventTiming.FIRST_SCHEDULED_TIME, 0) - \
req_perf_metrics_dict.get(RequestEventTiming.ARRIVAL_TIME, 0)
stat = {
MetricNames.TTFT: ttft,
MetricNames.E2E: e2e,
MetricNames.REQUEST_QUEUE_TIME: request_queue_time
}
if output_length > 1 and not is_multiple_response:
tpot = (req_perf_metrics_dict.get(
RequestEventTiming.LAST_TOKEN_TIME, 0) - req_perf_metrics_dict.get(
RequestEventTiming.FIRST_TOKEN_TIME, 0)) / (output_length - 1)
stat.update({MetricNames.TPOT: tpot})
stat = dict(filter(lambda item: item[1] > 0, stat.items()))
return stat
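
A worked example of the arithmetic in _process_req_perf_metrics, using hypothetical timestamps in seconds (not taken from this commit):

timings = {
    RequestEventTiming.ARRIVAL_TIME: 100.0,
    RequestEventTiming.FIRST_SCHEDULED_TIME: 100.2,
    RequestEventTiming.FIRST_TOKEN_TIME: 100.5,
    RequestEventTiming.LAST_TOKEN_TIME: 102.5,
}
stats = _process_req_perf_metrics(timings, output_length=11)
# TTFT               = 100.5 - 100.0 = 0.5 s
# E2E                = 102.5 - 100.0 = 2.5 s
# REQUEST_QUEUE_TIME = 100.2 - 100.0 = 0.2 s
# TPOT               = (102.5 - 100.5) / (11 - 1) = 0.2 s per output token
# Entries that come out zero or negative are dropped by the final filter.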

View File

@ -25,6 +25,7 @@ from ..llmapi.utils import (AsyncQueue, ManagedThread, _SyncQueue,
clear_sched_affinity, print_colored_debug, clear_sched_affinity, print_colored_debug,
print_traceback_on_error) print_traceback_on_error)
from ..lora_manager import LoraConfig, LoraManager from ..lora_manager import LoraConfig, LoraManager
from ..metrics import RequestEventTiming
from ..prompt_adapter_manager import PromptAdapterManager from ..prompt_adapter_manager import PromptAdapterManager
from ..runtime import ModelConfig from ..runtime import ModelConfig
from ..runtime.model_runner import _engine_config_to_model_config from ..runtime.model_runner import _engine_config_to_model_config
@ -899,10 +900,8 @@ class AwaitResponseHelper:
assert response is not None assert response is not None
queue = self.worker.return_queue(response.client_id) queue = self.worker.return_queue(response.client_id)
logprobs_result = _get_logprobs(self.worker, response, response = _maybe_wrap_response(self.worker, response,
self.worker._is_pytorch_backend) self.worker._is_pytorch_backend)
if logprobs_result:
response = ResponseWrapper(response, logprobs_result)
# For AsyncQueue.sync_q, we will batch the events to avoid too many # For AsyncQueue.sync_q, we will batch the events to avoid too many
# event notifications, thus put without wait here. # event notifications, thus put without wait here.
@ -940,10 +939,8 @@ class AwaitResponseHelper:
response = ErrorResponse(response.client_id, response.error_msg, response = ErrorResponse(response.client_id, response.error_msg,
response.request_id) response.request_id)
else: else:
logprobs_result = _get_logprobs(self.worker, response, response = _maybe_wrap_response(self.worker, response,
self.worker._is_pytorch_backend) self.worker._is_pytorch_backend)
if logprobs_result:
response = ResponseWrapper(response, logprobs_result)
_send_rsp(self.worker, _send_rsp(self.worker,
response, response,
@ -1051,3 +1048,41 @@ def _send_rsp(
worker._pop_result(response.client_id) worker._pop_result(response.client_id)
else: else:
raise ValueError(f"Unknown response type: {response}") raise ValueError(f"Unknown response type: {response}")
def _get_metrics_dict(
response: tllm.Response) -> dict[RequestEventTiming, float]:
req_perf_metrics, metrics_dict = None, {}
res = response.result
if res:
if hasattr(res, '_result'):
if result := res.get_result():
req_perf_metrics = result.request_perf_metrics
else:
req_perf_metrics = res.request_perf_metrics
if req_perf_metrics and req_perf_metrics.timing_metrics:
metrics_dict = {
RequestEventTiming.ARRIVAL_TIME:
req_perf_metrics.timing_metrics.arrival_time.total_seconds(),
RequestEventTiming.FIRST_TOKEN_TIME:
req_perf_metrics.timing_metrics.first_token_time.total_seconds(
),
RequestEventTiming.FIRST_SCHEDULED_TIME:
req_perf_metrics.timing_metrics.first_scheduled_time.
total_seconds(),
RequestEventTiming.LAST_TOKEN_TIME:
req_perf_metrics.timing_metrics.last_token_time.total_seconds()
}
return metrics_dict
def _maybe_wrap_response(
worker,
response: tllm.Response,
is_pytorch_backend=False) -> Union[tllm.Response, ResponseWrapper]:
logprobs_result = _get_logprobs(worker, response, is_pytorch_backend)
req_perf_metrics = _get_metrics_dict(response)
if logprobs_result or req_perf_metrics:
response = ResponseWrapper(response, logprobs_result, req_perf_metrics)
return response

View File

@ -548,7 +548,7 @@ class BaseLLM:
if sampling_params._stream_interval is None: if sampling_params._stream_interval is None:
sampling_params._stream_interval = getattr(self.args, sampling_params._stream_interval = getattr(self.args,
"stream_interval", 1) "stream_interval", 1)
sampling_params.return_perf_metrics = sampling_params.return_perf_metrics or self.args.return_perf_metrics
return sampling_params return sampling_params
def _check_arguments(self, prompt_len: int, query_len: int, def _check_arguments(self, prompt_len: int, query_len: int,

View File

@ -1311,6 +1311,10 @@ class BaseLlmArgs(StrictBaseModel):
status="deprecated", status="deprecated",
) )
return_perf_metrics: bool = Field(default=False,
description="Return perf metrics.",
status="prototype")
_parallel_config: Optional[object] = PrivateAttr(default=None) _parallel_config: Optional[object] = PrivateAttr(default=None)
_model_format: Optional[_ModelFormatKind] = PrivateAttr(default=None) _model_format: Optional[_ModelFormatKind] = PrivateAttr(default=None)
_speculative_model: Optional[str] = PrivateAttr(default=None) _speculative_model: Optional[str] = PrivateAttr(default=None)
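
For context, the new flag is exercised two ways by the tests added below: directly through the LLM constructor, or via an extra-options YAML for trtllm-serve. A minimal sketch (the model path is a placeholder):

from tensorrt_llm import LLM

llm = LLM(model="/path/to/model", return_perf_metrics=True)

# trtllm-serve equivalent: --extra_llm_api_options options.yaml, where
# options.yaml contains:
#   return_perf_metrics: true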

View File

@ -0,0 +1,4 @@
from .collector import *
from .enums import *
__all__ = ["MetricsCollector", "MetricNames", "RequestEventTiming"]

View File

@ -0,0 +1,105 @@
"""Utilities for Prometheus Metrics Collection."""
import time
from typing import Dict, Optional, Union
from .enums import MetricNames
# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0rc1/vllm/engine/metrics.py#L30
class MetricsCollector:
labelname_finish_reason = "finished_reason"
def __init__(self, labels: Dict[str, str]) -> None:
from prometheus_client import Counter, Histogram
self.last_log_time = time.time()
self.labels = labels
self.finish_reason_label = {
MetricsCollector.labelname_finish_reason: "unknown"
}
self.labels_with_finished_reason = {
**self.labels,
**self.finish_reason_label
}
self.counter_request_success = Counter(
name="request_success_total",
documentation="Count of successfully processed requests.",
labelnames=self.labels_with_finished_reason.keys())
self.histogram_e2e_time_request = Histogram(
name="e2e_request_latency_seconds",
documentation="Histogram of end to end request latency in seconds.",
buckets=[
0.3, 0.5, 0.8, 1.0, 1.5, 2.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0,
40.0, 50.0, 60.0, 120.0, 240.0, 480.0, 960.0, 1920.0, 7680.0
],
labelnames=self.labels.keys())
self.histogram_time_to_first_token = Histogram(
name="time_to_first_token_seconds",
documentation="Histogram of time to first token in seconds.",
buckets=[
0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5,
0.75, 1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0, 160.0, 640.0,
2560.0
],
labelnames=self.labels.keys())
self.histogram_time_per_output_token = Histogram(
name="time_per_output_token_seconds",
documentation="Histogram of time per output token in seconds.",
buckets=[
0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75,
1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0
],
labelnames=self.labels.keys())
self.histogram_queue_time_request = Histogram(
name="request_queue_time_seconds",
documentation=
"Histogram of time spent in WAITING phase for request.",
buckets=[
0.3, 0.5, 0.8, 1.0, 1.5, 2.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0,
40.0, 50.0, 60.0, 120.0, 240.0, 480.0, 960.0, 1920.0, 7680.0
],
labelnames=self.labels.keys())
def _label_merge(self, labels: Dict[str, str]) -> Dict[str, str]:
if labels is None or len(labels) == 0:
return self.labels
return {**self.labels, **labels}
def _log_counter(self, counter, labels: Dict[str, str],
data: Union[int, float]) -> None:
# Convenience function for logging to counter.
counter.labels(**self._label_merge(labels)).inc(data)
def _log_histogram(self, histogram, data: Union[int, float]) -> None:
# Convenience function for logging to histogram.
histogram.labels(**self.labels).observe(data)
def log_request_success(self, data: Union[int, float],
labels: Dict[str, str]) -> None:
self._log_counter(self.counter_request_success, labels, data)
self.last_log_time = time.time()
def log_histogram(self, data: Optional[dict[str, float]]) -> None:
if e2e := data.get(MetricNames.E2E, 0):
self._log_histogram(self.histogram_e2e_time_request, e2e)
if ttft := data.get(MetricNames.TTFT, 0):
self._log_histogram(self.histogram_time_to_first_token, ttft)
if tpot := data.get(MetricNames.TPOT, 0):
self._log_histogram(self.histogram_time_per_output_token, tpot)
if request_queue_time := data.get(MetricNames.REQUEST_QUEUE_TIME, 0):
self._log_histogram(self.histogram_queue_time_request,
request_queue_time)
self.last_log_time = time.time()
def log_metrics_dict(self, metrics_dict: dict[str, float]) -> None:
if finish_reason := metrics_dict.get(
MetricsCollector.labelname_finish_reason):
self.log_request_success(
1, {MetricsCollector.labelname_finish_reason: finish_reason})
self.log_histogram(metrics_dict)
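
A minimal usage sketch of the collector, with placeholder label values; the dict keys mirror what record_stats and _process_req_perf_metrics emit:

from tensorrt_llm.metrics import MetricNames, MetricsCollector

collector = MetricsCollector({"model_name": "example-model", "engine_type": "pytorch"})
collector.log_metrics_dict({
    MetricsCollector.labelname_finish_reason: "stop",  # bumps request_success_total
    MetricNames.E2E: 2.5,                              # e2e_request_latency_seconds
    MetricNames.TTFT: 0.5,                             # time_to_first_token_seconds
    MetricNames.TPOT: 0.2,                             # time_per_output_token_seconds
    MetricNames.REQUEST_QUEUE_TIME: 0.2,               # request_queue_time_seconds
})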

View File

@ -0,0 +1,15 @@
from enum import Enum
class MetricNames(Enum):
TTFT = "ttft"
TPOT = "tpot"
E2E = "e2e"
REQUEST_QUEUE_TIME = "request_queue_time"
class RequestEventTiming(Enum):
ARRIVAL_TIME = "arrival_time"
FIRST_TOKEN_TIME = "first_token_time" # nosec: B105
FIRST_SCHEDULED_TIME = "first_scheduled_time"
LAST_TOKEN_TIME = "last_token_time" # nosec: B105

View File

@ -1,6 +1,7 @@
#!/usr/bin/env python #!/usr/bin/env python
import asyncio import asyncio
import os import os
import re
import signal import signal
import traceback import traceback
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
@ -13,6 +14,7 @@ import uvicorn
from fastapi import FastAPI, Request from fastapi import FastAPI, Request
from fastapi.exceptions import RequestValidationError from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, Response, StreamingResponse from fastapi.responses import JSONResponse, Response, StreamingResponse
from starlette.routing import Mount
from transformers import AutoConfig, AutoProcessor from transformers import AutoConfig, AutoProcessor
from tensorrt_llm._tensorrt_engine import LLM from tensorrt_llm._tensorrt_engine import LLM
@ -25,6 +27,7 @@ from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams
from tensorrt_llm.llmapi.disagg_utils import MetadataServerConfig, ServerRole from tensorrt_llm.llmapi.disagg_utils import MetadataServerConfig, ServerRole
from tensorrt_llm.llmapi.llm import RequestOutput from tensorrt_llm.llmapi.llm import RequestOutput
from tensorrt_llm.logger import logger from tensorrt_llm.logger import logger
from tensorrt_llm.metrics.collector import MetricsCollector
from tensorrt_llm.serve.chat_utils import (check_multiple_response, from tensorrt_llm.serve.chat_utils import (check_multiple_response,
parse_chat_messages_coroutines) parse_chat_messages_coroutines)
from tensorrt_llm.serve.metadata_server import create_metadata_server from tensorrt_llm.serve.metadata_server import create_metadata_server
@ -42,7 +45,7 @@ from tensorrt_llm.serve.postprocess_handlers import (
completion_stream_post_processor) completion_stream_post_processor)
from tensorrt_llm.version import __version__ as VERSION from tensorrt_llm.version import __version__ as VERSION
from .._utils import nvtx_mark from .._utils import nvtx_mark, set_prometheus_multiproc_dir
# yapf: enable # yapf: enable
TIMEOUT_KEEP_ALIVE = 5 # seconds. TIMEOUT_KEEP_ALIVE = 5 # seconds.
@ -78,6 +81,13 @@ class OpenAIServer:
self.model = model_dir.name self.model = model_dir.name
else: else:
self.model = model self.model = model
self.metrics_collector = None
if self.llm.args.return_perf_metrics:
set_prometheus_multiproc_dir()
self.metrics_collector = MetricsCollector({
"model_name": "undefined",
"engine_type": "undefined"
})
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
@ -151,6 +161,32 @@ class OpenAIServer:
self.app.add_api_route("/v1/chat/completions", self.app.add_api_route("/v1/chat/completions",
self.openai_chat, self.openai_chat,
methods=["POST"]) methods=["POST"])
if self.llm.args.return_perf_metrics:
# register /prometheus/metrics
self.mount_metrics()
def mount_metrics(self):
# Lazy import for prometheus multiprocessing.
# We need to set PROMETHEUS_MULTIPROC_DIR environment variable
# before prometheus_client is imported.
# See https://prometheus.github.io/client_python/multiprocess/
from prometheus_client import (CollectorRegistry, make_asgi_app,
multiprocess)
from prometheus_fastapi_instrumentator import Instrumentator
registry = CollectorRegistry()
multiprocess.MultiProcessCollector(registry)
Instrumentator(
should_group_status_codes=False,
should_respect_env_var=True,
excluded_handlers=[
".*"
],
registry=registry,
).add().instrument(self.app).expose(self.app)
metrics_app = make_asgi_app(registry=registry)
metrics_route = Mount("/prometheus/metrics", metrics_app)
metrics_route.path_regex = re.compile("^/prometheus/metrics(?P<path>.*)$")
self.app.routes.append(metrics_route)
async def health(self) -> Response: async def health(self) -> Response:
return Response(status_code=200) return Response(status_code=200)
@ -228,6 +264,8 @@ class OpenAIServer:
post_processor, args = postproc_params.post_processor, postproc_params.postproc_args post_processor, args = postproc_params.post_processor, postproc_params.postproc_args
async for res in promise: async for res in promise:
pp_results = res.outputs[0]._postprocess_result if self.postproc_worker_enabled else post_processor(res, args) pp_results = res.outputs[0]._postprocess_result if self.postproc_worker_enabled else post_processor(res, args)
if res.finished and self.metrics_collector:
self.metrics_collector.log_metrics_dict(res.metrics_dict)
for pp_res in pp_results: for pp_res in pp_results:
yield pp_res yield pp_res
yield "data: [DONE]\n\n" yield "data: [DONE]\n\n"
@ -245,6 +283,8 @@ class OpenAIServer:
# Add prompt_tokens_ids to the response # Add prompt_tokens_ids to the response
if disaggregated_params and disaggregated_params.request_type and disaggregated_params.request_type == "context_only": if disaggregated_params and disaggregated_params.request_type and disaggregated_params.request_type == "context_only":
chat_response.prompt_token_ids = promise.prompt_token_ids chat_response.prompt_token_ids = promise.prompt_token_ids
if promise.finished and self.metrics_collector:
self.metrics_collector.log_metrics_dict(promise.metrics_dict)
return chat_response return chat_response
try: try:
@ -337,6 +377,8 @@ class OpenAIServer:
if disaggregated_params and disaggregated_params.request_type and disaggregated_params.request_type == "context_only": if disaggregated_params and disaggregated_params.request_type and disaggregated_params.request_type == "context_only":
# Include prompt token ids for context-only requests # Include prompt token ids for context-only requests
pp_result.prompt_token_ids = response.prompt_token_ids pp_result.prompt_token_ids = response.prompt_token_ids
if response.finished and self.metrics_collector:
self.metrics_collector.log_metrics_dict(response.metrics_dict)
return pp_result return pp_result
def merge_completion_responses(responses: List[CompletionResponse]) -> CompletionResponse: def merge_completion_responses(responses: List[CompletionResponse]) -> CompletionResponse:
@ -372,6 +414,8 @@ class OpenAIServer:
pp_result = post_processor(output, args) pp_result = post_processor(output, args)
else: else:
pp_result = output.outputs[0]._postprocess_result pp_result = output.outputs[0]._postprocess_result
if output.finished and self.metrics_collector:
self.metrics_collector.log_metrics_dict(output.metrics_dict)
for pp_res in pp_result: for pp_res in pp_result:
yield pp_res yield pp_res
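
With return_perf_metrics enabled, the mounted registry is scrapeable at /prometheus/metrics. A quick manual check mirroring the new unit test (host and port are placeholders):

from urllib.request import urlopen

body = urlopen("http://localhost:8000/prometheus/metrics").read().decode("utf-8")
assert "request_success_total" in body
assert "e2e_request_latency_seconds" in body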

View File

@ -1497,6 +1497,13 @@ def test_openai_chat_with_logit_bias(llm_root, llm_venv, sampler: str):
]) ])
def test_openai_prometheus(llm_root, llm_venv):
test_root = unittest_path() / "llmapi" / "apps"
llm_venv.run_cmd(
["-m", "pytest",
str(test_root / "_test_openai_prometheus.py")])
def test_openai_lora(llm_root, llm_venv): def test_openai_lora(llm_root, llm_venv):
test_root = unittest_path() / "llmapi" / "apps" test_root = unittest_path() / "llmapi" / "apps"
llm_venv.run_cmd(["-m", "pytest", str(test_root / "_test_openai_lora.py")]) llm_venv.run_cmd(["-m", "pytest", str(test_root / "_test_openai_lora.py")])

View File

@ -25,6 +25,7 @@ l0_a10:
- test_e2e.py::test_openai_chat_structural_tag_example - test_e2e.py::test_openai_chat_structural_tag_example
- test_e2e.py::test_openai_chat_json_example - test_e2e.py::test_openai_chat_json_example
- test_e2e.py::test_openai_chat_multimodal_example - test_e2e.py::test_openai_chat_multimodal_example
- test_e2e.py::test_openai_prometheus
- test_e2e.py::test_openai_lora - test_e2e.py::test_openai_lora
- test_e2e.py::test_trtllm_serve_multimodal_example - test_e2e.py::test_trtllm_serve_multimodal_example
- test_e2e.py::test_trtllm_serve_lora_example - test_e2e.py::test_trtllm_serve_lora_example

View File

@ -27,6 +27,10 @@ methods:
annotation: Optional[int] annotation: Optional[int]
default: null default: null
status: prototype status: prototype
return_perf_metrics:
annotation: bool
default: False
status: prototype
# Bindings and mirrored configs # Bindings and mirrored configs
peft_cache_config: peft_cache_config:
annotation: Optional[tensorrt_llm.llmapi.llm_args.PeftCacheConfig] annotation: Optional[tensorrt_llm.llmapi.llm_args.PeftCacheConfig]

View File

@ -11,4 +11,13 @@ methods:
clear_logprob_params: clear_logprob_params:
parameters: {} parameters: {}
return_annotation: None return_annotation: None
record_stats:
parameters:
output:
annotation: tensorrt_llm.executor.result.CompletionOutput
default: inspect._empty
stats:
annotation: Optional[dict[str, float]]
default: None
return_annotation: None
properties: {} properties: {}

View File

@ -0,0 +1,67 @@
import logging
import os
import tempfile
from urllib.request import urlopen
import pytest
import yaml
from ..test_llm import get_model_path
from .openai_server import RemoteOpenAIServer
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@pytest.fixture(scope="module", ids=["TinyLlama-1.1B-Chat"])
def model_name():
return "llama-models-v2/TinyLlama-1.1B-Chat-v1.0"
@pytest.fixture(scope="module")
def temp_extra_llm_api_options_file(request):
temp_dir = tempfile.gettempdir()
temp_file_path = os.path.join(temp_dir, "extra_llm_api_options.yaml")
try:
extra_llm_api_options_dict = {"return_perf_metrics": True}
with open(temp_file_path, 'w') as f:
yaml.dump(extra_llm_api_options_dict, f)
yield temp_file_path
finally:
if os.path.exists(temp_file_path):
os.remove(temp_file_path)
@pytest.fixture(scope="module")
def server(model_name: str,
temp_extra_llm_api_options_file: str) -> RemoteOpenAIServer:
model_path = get_model_path(model_name)
args = ["--backend", "pytorch", "--tp_size", "1"]
args.extend(["--extra_llm_api_options", temp_extra_llm_api_options_file])
logger.info(f"Starting server, model: {model_name}, args: {args}")
with RemoteOpenAIServer(model_path, args) as remote_server:
yield remote_server
logger.info("Tests completed, shutting down server")
def test_metrics_endpoint(server: RemoteOpenAIServer):
client = server.get_client()
client.completions.create(
model="Server",
prompt="Hello, my name is",
max_tokens=25,
stream=False,
)
response = urlopen(f'{server.url_root}/prometheus/metrics')
assert response.status == 200
data = response.read().decode("utf-8")
assert "request_success_total" in data
assert "e2e_request_latency_seconds" in data
assert "time_to_first_token_seconds" in data
assert "request_queue_time_seconds" in data

View File

@ -6,6 +6,7 @@ from tensorrt_llm import LLM
from tensorrt_llm.llmapi import KvCacheConfig from tensorrt_llm.llmapi import KvCacheConfig
from tensorrt_llm.llmapi.llm_args import PeftCacheConfig from tensorrt_llm.llmapi.llm_args import PeftCacheConfig
from tensorrt_llm.llmapi.tokenizer import TransformersTokenizer from tensorrt_llm.llmapi.tokenizer import TransformersTokenizer
from tensorrt_llm.metrics import MetricNames
from tensorrt_llm.sampling_params import SamplingParams from tensorrt_llm.sampling_params import SamplingParams
# isort: off # isort: off
@ -195,6 +196,27 @@ def test_llm_perf_metrics():
assert perf_metrics.last_iter == perf_metrics.iter assert perf_metrics.last_iter == perf_metrics.iter
def test_llm_prometheus():
test_prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
sampling_params = SamplingParams(max_tokens=10, temperature=0.8, top_p=0.95)
llm = LLM(model=llama_model_path,
return_perf_metrics=True,
kv_cache_config=global_kvcache_config)
for test_prompt in test_prompts:
request_output = llm.generate(test_prompt, sampling_params)
assert request_output.metrics_dict is not None
assert MetricNames.REQUEST_QUEUE_TIME in request_output.metrics_dict
assert MetricNames.TPOT in request_output.metrics_dict
assert MetricNames.TTFT in request_output.metrics_dict
assert MetricNames.E2E in request_output.metrics_dict
assert request_output.outputs is not None
@pytest.mark.parametrize("streaming", [True, False]) @pytest.mark.parametrize("streaming", [True, False])
def test_llm_with_postprocess_parallel_and_result_handler(streaming): def test_llm_with_postprocess_parallel_and_result_handler(streaming):
run_llm_with_postprocess_parallel_and_result_handler(streaming, run_llm_with_postprocess_parallel_and_result_handler(streaming,