[None][feat] Core Metrics Implementation (#5785)

Signed-off-by: Ye Zhang <zhysishu@gmail.com>
Signed-off-by: Shunkang <182541032+Shunkangz@users.noreply.github.co>
This commit is contained in:
Ye Zhang 2025-08-09 14:48:53 +08:00 committed by GitHub
parent 97787883c3
commit bcf5ec0c9a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 442 additions and 21 deletions

View File

@ -29,6 +29,8 @@ nvidia-modelopt[torch]~=0.33.0
nvidia-nccl-cu12
nvidia-cuda-nvrtc-cu12
transformers==4.55.0
prometheus_client
prometheus_fastapi_instrumentator
pydantic>=2.9.1
pydantic-settings[yaml]
omegaconf

View File

@ -250,6 +250,12 @@ class LlmResult:
self._result = tensorrt_llm.bindings.executor.deserialize_result(
self._result)
def get_result(self):
if tmp_res := tensorrt_llm.bindings.executor.deserialize_result(
self._result):
return tmp_res
return None
@dataclass
class LlmResponse:

View File

@ -20,6 +20,7 @@ import linecache
import math
import os
import struct
import tempfile
import trace
import weakref
from contextlib import contextmanager
@ -1112,3 +1113,17 @@ def is_multi_device_enable():
the number of devices
"""
return local_mpi_size() > 1
def set_prometheus_multiproc_dir() -> object:
# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.10/python/sglang/srt/utils.py#L1266
global prometheus_multiproc_dir
if "PROMETHEUS_MULTIPROC_DIR" in os.environ:
logger.info("User set PROMETHEUS_MULTIPROC_DIR detected.")
prometheus_multiproc_dir = tempfile.TemporaryDirectory(
dir=os.environ["PROMETHEUS_MULTIPROC_DIR"])
else:
prometheus_multiproc_dir = tempfile.TemporaryDirectory()
os.environ["PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name
logger.info(
f"PROMETHEUS_MULTIPROC_DIR: {os.environ['PROMETHEUS_MULTIPROC_DIR']}")

View File

@ -3,7 +3,7 @@ import traceback
from collections import deque
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple,
Optional)
Optional, Union)
import zmq
import zmq.asyncio
@ -18,7 +18,7 @@ from .utils import is_llm_response
if TYPE_CHECKING:
from .result import (DetokenizedGenerationResultBase, GenerationResult,
GenerationResultBase)
GenerationResultBase, ResponseWrapper)
__all__ = [
"PostprocWorker",
@ -57,7 +57,7 @@ class PostprocWorker:
@dataclass
class Input:
rsp: "tllm.Response"
rsp: Union["tllm.Response", "ResponseWrapper"]
# The information necessary for creating a GenerationResult in the first Input for each request
sampling_params: Optional[SamplingParams] = None
@ -69,6 +69,7 @@ class PostprocWorker:
res: Any
is_final: bool
error: str = ""
metrics: Optional[dict[str, float]] = None
def __init__(
self,
@ -118,7 +119,9 @@ class PostprocWorker:
streaming=inp.streaming,
tokenizer=tokenizer)
async def _handle_input(self, input: "PostprocWorker.Input") -> Any:
async def _handle_input(
self, input: Union["PostprocWorker.Input", "ResponseWrapper"]
) -> [Any, Optional[dict[str, float]]]:
''' Handle a single response from await_response worker. '''
if input.rsp.result.context_logits is not None or \
input.rsp.result.generation_logits is not None:
@ -139,6 +142,7 @@ class PostprocWorker:
record._handle_response(input.rsp) # inplace
# Left the result_handler determine the final output dtype.
# NOTE: This will change the CompletionOutput._postprocess_result
metrics_dict = record.metrics_dict
if postproc_params := record.postproc_params:
result_handler, args = postproc_params.post_processor, postproc_params.postproc_args
args.tokenizer = self._tokenizer
@ -150,7 +154,7 @@ class PostprocWorker:
# TODO: Keep only the diff token_ids and text in streaming mode when
# result_handler is not set
return out
return out, metrics_dict
async def _batched_put(self):
''' Batched IPC send. '''
@ -173,8 +177,12 @@ class PostprocWorker:
client_id = inp.rsp.client_id
is_final = inp.rsp.result.is_final if is_llm_response(
inp.rsp) else True
res = await self._handle_input(inp)
batch.append(PostprocWorker.Output(client_id, res, is_final))
res, metrics = await self._handle_input(inp)
batch.append(
PostprocWorker.Output(client_id=client_id,
res=res,
is_final=is_final,
metrics=metrics))
if is_final:
self._records.pop(client_id)

View File

@ -15,6 +15,7 @@ from ..bindings import executor as tllm
from ..disaggregated_params import DisaggregatedParams
from ..llmapi.tracer import global_tracer
from ..llmapi.utils import AsyncQueue
from ..metrics import MetricNames, MetricsCollector, RequestEventTiming
from ..sampling_params import LogprobParams, SamplingParams
from .utils import ErrorResponse, has_event_loop, is_llm_response
@ -50,14 +51,18 @@ class LogProbsResult(NamedTuple):
class ResponseWrapper:
"""Wrapper of runtime response with optional outputs computed post runtime.
"""
1. Wrapper of runtime response with optional outputs computed post runtime.
2. A workaround to pass around RequestPerfMetrics.
"""
def __init__(self,
response: Union["PostprocWorker.Output", tllm.Response],
logprobs: Optional[LogProbsResult] = None):
logprobs: Optional[LogProbsResult] = None,
request_perf_metrics: Optional[dict[str, float]] = None):
self._response = response
self.logprobs = logprobs
self.request_perf_metrics = request_perf_metrics
@property
def _is_llm_response(self):
@ -68,6 +73,14 @@ class ResponseWrapper:
response = object.__getattribute__(self, '_response')
return getattr(response, name)
def __getstate__(self):
return (self._response, self.logprobs, self.request_perf_metrics)
def __setstate__(self, state):
self._response = state[0]
self.logprobs = state[1]
self.request_perf_metrics = state[2]
@dataclass(slots=True)
class CompletionOutput:
@ -146,6 +159,7 @@ class GenerationResultBase:
self.disaggregated_params = None
self.decoding_iter = 0
self._done = False
self.metrics_dict = {}
if has_event_loop():
self.aqueue = AsyncQueue()
@ -201,7 +215,9 @@ class GenerationResultBase:
finish_reasons,
response_tensors,
sequence_index,
logprobs_result=None):
logprobs_result=None,
req_perf_metrics_dict: Optional[dict[str,
float]] = None):
""" Handle a single sequence in the response. """
seq_idx = sequence_index
@ -271,6 +287,7 @@ class GenerationResultBase:
else:
raise ValueError(
f"Unknown finish reason: {finish_reasons[src_idx]}")
self.record_stats(output, req_perf_metrics_dict)
@nvtx_range_debug("handle_response",
color="red",
@ -278,7 +295,9 @@ class GenerationResultBase:
def _handle_response(self,
response: Union["PostprocWorker.Output", tllm.Response,
ResponseWrapper, ErrorResponse]):
req_perf_metrics_dict = None
if isinstance(response, ResponseWrapper):
req_perf_metrics_dict = response.request_perf_metrics
logprobs_result = response.logprobs
response = response._response
else:
@ -291,6 +310,8 @@ class GenerationResultBase:
self._outputs[0] = response.res
else:
self._outputs[0]._postprocess_result = response.res
if response.metrics:
self.metrics_dict = response.metrics
if response.error:
if self._background_error_handler is not None and (
@ -303,7 +324,8 @@ class GenerationResultBase:
handler(response.error_msg)
response_result = response.result
if hasattr(response_result, "_result"):
if hasattr(response_result, "_result") and isinstance(
response_result._result, bytes):
response_result.deserialize()
self._done = response_result.is_final
@ -322,11 +344,12 @@ class GenerationResultBase:
if self.sampling_params.use_beam_search:
for beam_idx, _ in enumerate(response_result.output_token_ids):
self._handle_sequence(finish_reasons, response_result,
beam_idx, logprobs_result)
beam_idx, logprobs_result,
req_perf_metrics_dict)
else:
self._handle_sequence(finish_reasons, response_result,
response_result.sequence_index,
logprobs_result)
logprobs_result, req_perf_metrics_dict)
if response_result.context_logits is not None:
self._context_logits = response_result.context_logits
@ -342,6 +365,29 @@ class GenerationResultBase:
else:
raise ValueError(f"Unknown response type: {response}")
def record_stats(self,
output: CompletionOutput,
stats: Optional[dict[str, float]] = None) -> None:
"""Record the stats of the generation result.
Args:
output (CompletionOutput): The output of the generation result.
stats (Optional[dict[str, float]]): The stats of the generation result. Defaults to None.
"""
if not stats:
return
metrics_stats = {}
if output.finish_reason:
metrics_stats.update({
MetricsCollector.labelname_finish_reason:
output.finish_reason
})
processed_metrics_stat = _process_req_perf_metrics(
stats, len(output.token_ids), self.sampling_params.n > 1)
if processed_metrics_stat:
metrics_stats.update(processed_metrics_stat)
self.metrics_dict = metrics_stats
class DetokenizedGenerationResultBase(GenerationResultBase):
''' The base class for the generation result with detokenization support. '''
@ -688,3 +734,30 @@ def compute_logprobs(
return LogProbsResult(prompt=prompt_logprobs,
generation=generation_logprobs)
def _process_req_perf_metrics(
req_perf_metrics_dict: Optional[dict[str, float]],
output_length: int,
is_multiple_response: bool = False) -> dict[MetricNames, float]:
stat = {}
if not req_perf_metrics_dict:
return stat
ttft = req_perf_metrics_dict.get(RequestEventTiming.FIRST_TOKEN_TIME, 0) - \
req_perf_metrics_dict.get(RequestEventTiming.ARRIVAL_TIME, 0)
e2e = req_perf_metrics_dict.get(RequestEventTiming.LAST_TOKEN_TIME, 0) - \
req_perf_metrics_dict.get(RequestEventTiming.ARRIVAL_TIME, 0)
request_queue_time = req_perf_metrics_dict.get(RequestEventTiming.FIRST_SCHEDULED_TIME, 0) - \
req_perf_metrics_dict.get(RequestEventTiming.ARRIVAL_TIME, 0)
stat = {
MetricNames.TTFT: ttft,
MetricNames.E2E: e2e,
MetricNames.REQUEST_QUEUE_TIME: request_queue_time
}
if output_length > 1 and not is_multiple_response:
tpot = (req_perf_metrics_dict.get(
RequestEventTiming.LAST_TOKEN_TIME, 0) - req_perf_metrics_dict.get(
RequestEventTiming.FIRST_TOKEN_TIME, 0)) / (output_length - 1)
stat.update({MetricNames.TPOT: tpot})
stat = dict(filter(lambda item: item[1] > 0, stat.items()))
return stat

View File

@ -25,6 +25,7 @@ from ..llmapi.utils import (AsyncQueue, ManagedThread, _SyncQueue,
clear_sched_affinity, print_colored_debug,
print_traceback_on_error)
from ..lora_manager import LoraConfig, LoraManager
from ..metrics import RequestEventTiming
from ..prompt_adapter_manager import PromptAdapterManager
from ..runtime import ModelConfig
from ..runtime.model_runner import _engine_config_to_model_config
@ -899,10 +900,8 @@ class AwaitResponseHelper:
assert response is not None
queue = self.worker.return_queue(response.client_id)
logprobs_result = _get_logprobs(self.worker, response,
response = _maybe_wrap_response(self.worker, response,
self.worker._is_pytorch_backend)
if logprobs_result:
response = ResponseWrapper(response, logprobs_result)
# For AsyncQueue.sync_q, we will batch the events to avoid too many
# event notifications, thus put without wait here.
@ -940,10 +939,8 @@ class AwaitResponseHelper:
response = ErrorResponse(response.client_id, response.error_msg,
response.request_id)
else:
logprobs_result = _get_logprobs(self.worker, response,
response = _maybe_wrap_response(self.worker, response,
self.worker._is_pytorch_backend)
if logprobs_result:
response = ResponseWrapper(response, logprobs_result)
_send_rsp(self.worker,
response,
@ -1051,3 +1048,41 @@ def _send_rsp(
worker._pop_result(response.client_id)
else:
raise ValueError(f"Unknown response type: {response}")
def _get_metrics_dict(
response: tllm.Response) -> dict[RequestEventTiming, float]:
req_perf_metrics, metrics_dict = None, {}
res = response.result
if res:
if hasattr(res, '_result'):
if result := res.get_result():
req_perf_metrics = result.request_perf_metrics
else:
req_perf_metrics = res.request_perf_metrics
if req_perf_metrics and req_perf_metrics.timing_metrics:
metrics_dict = {
RequestEventTiming.ARRIVAL_TIME:
req_perf_metrics.timing_metrics.arrival_time.total_seconds(),
RequestEventTiming.FIRST_TOKEN_TIME:
req_perf_metrics.timing_metrics.first_token_time.total_seconds(
),
RequestEventTiming.FIRST_SCHEDULED_TIME:
req_perf_metrics.timing_metrics.first_scheduled_time.
total_seconds(),
RequestEventTiming.LAST_TOKEN_TIME:
req_perf_metrics.timing_metrics.last_token_time.total_seconds()
}
return metrics_dict
def _maybe_wrap_response(
worker,
response: tllm.Response,
is_pytorch_backend=False) -> Union[tllm.Response, ResponseWrapper]:
logprobs_result = _get_logprobs(worker, response, is_pytorch_backend)
req_perf_metrics = _get_metrics_dict(response)
if logprobs_result or req_perf_metrics:
response = ResponseWrapper(response, logprobs_result, req_perf_metrics)
return response

View File

@ -548,7 +548,7 @@ class BaseLLM:
if sampling_params._stream_interval is None:
sampling_params._stream_interval = getattr(self.args,
"stream_interval", 1)
sampling_params.return_perf_metrics = sampling_params.return_perf_metrics or self.args.return_perf_metrics
return sampling_params
def _check_arguments(self, prompt_len: int, query_len: int,

View File

@ -1311,6 +1311,10 @@ class BaseLlmArgs(StrictBaseModel):
status="deprecated",
)
return_perf_metrics: bool = Field(default=False,
description="Return perf metrics.",
status="prototype")
_parallel_config: Optional[object] = PrivateAttr(default=None)
_model_format: Optional[_ModelFormatKind] = PrivateAttr(default=None)
_speculative_model: Optional[str] = PrivateAttr(default=None)

View File

@ -0,0 +1,4 @@
from .collector import *
from .enums import *
__all__ = ["MetricsCollector", "MetricNames", "RequestEventTiming"]

View File

@ -0,0 +1,105 @@
"""Utilities for Prometheus Metrics Collection."""
import time
from typing import Dict, Optional, Union
from .enums import MetricNames
# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0rc1/vllm/engine/metrics.py#L30
class MetricsCollector:
labelname_finish_reason = "finished_reason"
def __init__(self, labels: Dict[str, str]) -> None:
from prometheus_client import Counter, Histogram
self.last_log_time = time.time()
self.labels = labels
self.finish_reason_label = {
MetricsCollector.labelname_finish_reason: "unknown"
}
self.labels_with_finished_reason = {
**self.labels,
**self.finish_reason_label
}
self.counter_request_success = Counter(
name="request_success_total",
documentation="Count of successfully processed requests.",
labelnames=self.labels_with_finished_reason.keys())
self.histogram_e2e_time_request = Histogram(
name="e2e_request_latency_seconds",
documentation="Histogram of end to end request latency in seconds.",
buckets=[
0.3, 0.5, 0.8, 1.0, 1.5, 2.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0,
40.0, 50.0, 60.0, 120.0, 240.0, 480.0, 960.0, 1920.0, 7680.0
],
labelnames=self.labels.keys())
self.histogram_time_to_first_token = Histogram(
name="time_to_first_token_seconds",
documentation="Histogram of time to first token in seconds.",
buckets=[
0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5,
0.75, 1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0, 160.0, 640.0,
2560.0
],
labelnames=self.labels.keys())
self.histogram_time_per_output_token = Histogram(
name="time_per_output_token_seconds",
documentation="Histogram of time per output token in seconds.",
buckets=[
0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75,
1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0
],
labelnames=self.labels.keys())
self.histogram_queue_time_request = Histogram(
name="request_queue_time_seconds",
documentation=
"Histogram of time spent in WAITING phase for request.",
buckets=[
0.3, 0.5, 0.8, 1.0, 1.5, 2.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0,
40.0, 50.0, 60.0, 120.0, 240.0, 480.0, 960.0, 1920.0, 7680.0
],
labelnames=self.labels.keys())
def _label_merge(self, labels: Dict[str, str]) -> Dict[str, str]:
if labels is None or len(labels) == 0:
return self.labels
return {**self.labels, **labels}
def _log_counter(self, counter, labels: Dict[str, str],
data: Union[int, float]) -> None:
# Convenience function for logging to counter.
counter.labels(**self._label_merge(labels)).inc(data)
def _log_histogram(self, histogram, data: Union[int, float]) -> None:
# Convenience function for logging to histogram.
histogram.labels(**self.labels).observe(data)
def log_request_success(self, data: Union[int, float],
labels: Dict[str, str]) -> None:
self._log_counter(self.counter_request_success, labels, data)
self.last_log_time = time.time()
def log_histogram(self, data: Optional[dict[str, float]]) -> None:
if e2e := data.get(MetricNames.E2E, 0):
self._log_histogram(self.histogram_e2e_time_request, e2e)
if ttft := data.get(MetricNames.TTFT, 0):
self._log_histogram(self.histogram_time_to_first_token, ttft)
if tpot := data.get(MetricNames.TPOT, 0):
self._log_histogram(self.histogram_time_per_output_token, tpot)
if request_queue_time := data.get(MetricNames.REQUEST_QUEUE_TIME, 0):
self._log_histogram(self.histogram_queue_time_request,
request_queue_time)
self.last_log_time = time.time()
def log_metrics_dict(self, metrics_dict: dict[str, float]) -> None:
if finish_reason := metrics_dict.get(
MetricsCollector.labelname_finish_reason):
self.log_request_success(
1, {MetricsCollector.labelname_finish_reason: finish_reason})
self.log_histogram(metrics_dict)

View File

@ -0,0 +1,15 @@
from enum import Enum
class MetricNames(Enum):
TTFT = "ttft"
TPOT = "tpot"
E2E = "e2e"
REQUEST_QUEUE_TIME = "request_queue_time"
class RequestEventTiming(Enum):
ARRIVAL_TIME = "arrival_time"
FIRST_TOKEN_TIME = "first_token_time" # nosec: B105
FIRST_SCHEDULED_TIME = "first_scheduled_time"
LAST_TOKEN_TIME = "last_token_time" # nosec: B105

View File

@ -1,6 +1,7 @@
#!/usr/bin/env python
import asyncio
import os
import re
import signal
import traceback
from contextlib import asynccontextmanager
@ -13,6 +14,7 @@ import uvicorn
from fastapi import FastAPI, Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, Response, StreamingResponse
from starlette.routing import Mount
from transformers import AutoConfig, AutoProcessor
from tensorrt_llm._tensorrt_engine import LLM
@ -25,6 +27,7 @@ from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams
from tensorrt_llm.llmapi.disagg_utils import MetadataServerConfig, ServerRole
from tensorrt_llm.llmapi.llm import RequestOutput
from tensorrt_llm.logger import logger
from tensorrt_llm.metrics.collector import MetricsCollector
from tensorrt_llm.serve.chat_utils import (check_multiple_response,
parse_chat_messages_coroutines)
from tensorrt_llm.serve.metadata_server import create_metadata_server
@ -42,7 +45,7 @@ from tensorrt_llm.serve.postprocess_handlers import (
completion_stream_post_processor)
from tensorrt_llm.version import __version__ as VERSION
from .._utils import nvtx_mark
from .._utils import nvtx_mark, set_prometheus_multiproc_dir
# yapf: enale
TIMEOUT_KEEP_ALIVE = 5 # seconds.
@ -78,6 +81,13 @@ class OpenAIServer:
self.model = model_dir.name
else:
self.model = model
self.metrics_collector = None
if self.llm.args.return_perf_metrics:
set_prometheus_multiproc_dir()
self.metrics_collector = MetricsCollector({
"model_name": "undefined",
"engine_type": "undefined"
})
@asynccontextmanager
async def lifespan(app: FastAPI):
@ -151,6 +161,32 @@ class OpenAIServer:
self.app.add_api_route("/v1/chat/completions",
self.openai_chat,
methods=["POST"])
if self.llm.args.return_perf_metrics:
# register /prometheus/metrics
self.mount_metrics()
def mount_metrics(self):
# Lazy import for prometheus multiprocessing.
# We need to set PROMETHEUS_MULTIPROC_DIR environment variable
# before prometheus_client is imported.
# See https://prometheus.github.io/client_python/multiprocess/
from prometheus_client import (CollectorRegistry, make_asgi_app,
multiprocess)
from prometheus_fastapi_instrumentator import Instrumentator
registry = CollectorRegistry()
multiprocess.MultiProcessCollector(registry)
Instrumentator(
should_group_status_codes=False,
should_respect_env_var=True,
excluded_handlers=[
".*"
],
registry=registry,
).add().instrument(self.app).expose(self.app)
metrics_app = make_asgi_app(registry=registry)
metrics_route = Mount("/prometheus/metrics", metrics_app)
metrics_route.path_regex = re.compile("^/prometheus/metrics(?P<path>.*)$")
self.app.routes.append(metrics_route)
async def health(self) -> Response:
return Response(status_code=200)
@ -228,6 +264,8 @@ class OpenAIServer:
post_processor, args = postproc_params.post_processor, postproc_params.postproc_args
async for res in promise:
pp_results = res.outputs[0]._postprocess_result if self.postproc_worker_enabled else post_processor(res, args)
if res.finished and self.metrics_collector:
self.metrics_collector.log_metrics_dict(res.metrics_dict)
for pp_res in pp_results:
yield pp_res
yield "data: [DONE]\n\n"
@ -245,6 +283,8 @@ class OpenAIServer:
# Add prompt_tokens_ids to the response
if disaggregated_params and disaggregated_params.request_type and disaggregated_params.request_type == "context_only":
chat_response.prompt_token_ids = promise.prompt_token_ids
if promise.finished and self.metrics_collector:
self.metrics_collector.log_metrics_dict(promise.metrics_dict)
return chat_response
try:
@ -337,6 +377,8 @@ class OpenAIServer:
if disaggregated_params and disaggregated_params.request_type and disaggregated_params.request_type == "context_only":
# Include prompt token ids for context-only requests
pp_result.prompt_token_ids = response.prompt_token_ids
if response.finished and self.metrics_collector:
self.metrics_collector.log_metrics_dict(response.metrics_dict)
return pp_result
def merge_completion_responses(responses: List[CompletionResponse]) -> CompletionResponse:
@ -372,6 +414,8 @@ class OpenAIServer:
pp_result = post_processor(output, args)
else:
pp_result = output.outputs[0]._postprocess_result
if output.finished and self.metrics_collector:
self.metrics_collector.log_metrics_dict(output.metrics_dict)
for pp_res in pp_result:
yield pp_res

View File

@ -1497,6 +1497,13 @@ def test_openai_chat_with_logit_bias(llm_root, llm_venv, sampler: str):
])
def test_openai_prometheus(llm_root, llm_venv):
test_root = unittest_path() / "llmapi" / "apps"
llm_venv.run_cmd(
["-m", "pytest",
str(test_root / "_test_openai_prometheus.py")])
def test_openai_lora(llm_root, llm_venv):
test_root = unittest_path() / "llmapi" / "apps"
llm_venv.run_cmd(["-m", "pytest", str(test_root / "_test_openai_lora.py")])

View File

@ -25,6 +25,7 @@ l0_a10:
- test_e2e.py::test_openai_chat_structural_tag_example
- test_e2e.py::test_openai_chat_json_example
- test_e2e.py::test_openai_chat_multimodal_example
- test_e2e.py::test_openai_prometheus
- test_e2e.py::test_openai_lora
- test_e2e.py::test_trtllm_serve_multimodal_example
- test_e2e.py::test_trtllm_serve_lora_example

View File

@ -27,6 +27,10 @@ methods:
annotation: Optional[int]
default: null
status: prototype
return_perf_metrics:
annotation: bool
default: False
status: prototype
# Bindings and mirrored configs
peft_cache_config:
annotation: Optional[tensorrt_llm.llmapi.llm_args.PeftCacheConfig]

View File

@ -11,4 +11,13 @@ methods:
clear_logprob_params:
parameters: {}
return_annotation: None
record_stats:
parameters:
output:
annotation: tensorrt_llm.executor.result.CompletionOutput
default: inspect._empty
stats:
annotation: Optional[dict[str, float]]
default: None
return_annotation: None
properties: {}

View File

@ -0,0 +1,67 @@
import logging
import os
import tempfile
from urllib.request import urlopen
import pytest
import yaml
from ..test_llm import get_model_path
from .openai_server import RemoteOpenAIServer
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@pytest.fixture(scope="module", ids=["TinyLlama-1.1B-Chat"])
def model_name():
return "llama-models-v2/TinyLlama-1.1B-Chat-v1.0"
@pytest.fixture(scope="module")
def temp_extra_llm_api_options_file(request):
temp_dir = tempfile.gettempdir()
temp_file_path = os.path.join(temp_dir, "extra_llm_api_options.yaml")
try:
extra_llm_api_options_dict = {"return_perf_metrics": True}
with open(temp_file_path, 'w') as f:
yaml.dump(extra_llm_api_options_dict, f)
yield temp_file_path
finally:
if os.path.exists(temp_file_path):
os.remove(temp_file_path)
@pytest.fixture(scope="module")
def server(model_name: str,
temp_extra_llm_api_options_file: str) -> RemoteOpenAIServer:
model_path = get_model_path(model_name)
args = ["--backend", "pytorch", "--tp_size", "1"]
args.extend(["--extra_llm_api_options", temp_extra_llm_api_options_file])
logger.info(f"Starting server, model: {model_name}, args: {args}")
with RemoteOpenAIServer(model_path, args) as remote_server:
yield remote_server
logger.info("Tests completed, shutting down server")
def test_metrics_endpoint(server: RemoteOpenAIServer):
client = server.get_client()
client.completions.create(
model="Server",
prompt="Hello, my name is",
max_tokens=25,
stream=False,
)
response = urlopen(f'{server.url_root}/prometheus/metrics')
assert response.status is 200
data = response.read().decode("utf-8")
assert "request_success_total" in data
assert "e2e_request_latency_seconds" in data
assert "time_to_first_token_seconds" in data
assert "request_queue_time_seconds" in data

View File

@ -6,6 +6,7 @@ from tensorrt_llm import LLM
from tensorrt_llm.llmapi import KvCacheConfig
from tensorrt_llm.llmapi.llm_args import PeftCacheConfig
from tensorrt_llm.llmapi.tokenizer import TransformersTokenizer
from tensorrt_llm.metrics import MetricNames
from tensorrt_llm.sampling_params import SamplingParams
# isort: off
@ -195,6 +196,27 @@ def test_llm_perf_metrics():
assert perf_metrics.last_iter == perf_metrics.iter
def test_llm_prometheus():
test_prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
sampling_params = SamplingParams(max_tokens=10, temperature=0.8, top_p=0.95)
llm = LLM(model=llama_model_path,
return_perf_metrics=True,
kv_cache_config=global_kvcache_config)
for test_prompt in test_prompts:
request_output = llm.generate(test_prompt, sampling_params)
assert request_output.metrics_dict is not None
assert MetricNames.REQUEST_QUEUE_TIME in request_output.metrics_dict
assert MetricNames.TPOT in request_output.metrics_dict
assert MetricNames.TTFT in request_output.metrics_dict
assert MetricNames.E2E in request_output.metrics_dict
assert request_output.outputs is not None
@pytest.mark.parametrize("streaming", [True, False])
def test_llm_with_postprocess_parallel_and_result_handler(streaming):
run_llm_with_postprocess_parallel_and_result_handler(streaming,