[None][feat] Add opentelemetry tracing (#5897)

Signed-off-by: Zhang Haotong <zhanghaotong.zht@antgroup.com>
Signed-off-by: zhanghaotong <zhanghaotong.zht@antgroup.com>
Signed-off-by: Shunkang <182541032+Shunkangz@users.noreply.github.com>
Co-authored-by: Zhang Haotong <zhanghaotong.zht@alibaba-inc.com>
Co-authored-by: Shunkang <182541032+Shunkangz@users.noreply.github.com>
zhanghaotong 2025-10-27 18:51:07 +08:00 committed by GitHub
parent ce0d76135d
commit 1026069a2b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 591 additions and 35 deletions

View File

@ -0,0 +1,85 @@
# OpenTelemetry Integration Guide
This guide explains how to set up OpenTelemetry tracing in TensorRT-LLM to monitor and debug your LLM inference services.
## Install OpenTelemetry
Install the required OpenTelemetry packages:
```bash
pip install \
'opentelemetry-sdk' \
'opentelemetry-api' \
'opentelemetry-exporter-otlp' \
'opentelemetry-semantic-conventions-ai'
```
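Before starting the server, you can quickly confirm the packages are importable with a one-line check (a minimal sanity check; it only exercises the imports the tracer relies on):
```bash
python -c "from opentelemetry.sdk.trace import TracerProvider; from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter; print('OpenTelemetry SDK and OTLP exporter are available')"
```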
## Start Jaeger
You can start Jaeger with Docker:
```bash
docker run --rm --name jaeger \
-e COLLECTOR_ZIPKIN_HOST_PORT=:9411 \
-p 6831:6831/udp \
-p 6832:6832/udp \
-p 5778:5778 \
-p 16686:16686 \
-p 4317:4317 \
-p 4318:4318 \
-p 14250:14250 \
-p 14268:14268 \
-p 14269:14269 \
-p 9411:9411 \
jaegertracing/all-in-one:1.57.0
```
Or run the jaeger-all-in-one(.exe) executable from [the binary distribution archives](https://www.jaegertracing.io/download/):
```bash
jaeger-all-in-one --collector.zipkin.host-port=:9411
```
## Set up environment variables and run TensorRT-LLM
Set up the environment variables:
```bash
export JAEGER_IP=$(docker inspect --format '{{ .NetworkSettings.IPAddress }}' jaeger)
export OTEL_EXPORTER_OTLP_TRACES_PROTOCOL=grpc
export OTEL_EXPORTER_OTLP_TRACES_ENDPOINT=grpc://$JAEGER_IP:4317
export OTEL_EXPORTER_OTLP_TRACES_INSECURE=true
export OTEL_SERVICE_NAME="trt-server"
```
Then run TensorRT-LLM with OpenTelemetry tracing enabled, and make sure `return_perf_metrics` is set to true in the model configuration:
```bash
trtllm-serve models/Qwen3-8B/ --otlp_traces_endpoint="$OTEL_EXPORTER_OTLP_TRACES_ENDPOINT"
```
## Send requests and find traces in Jaeger
You can send a request to the server and view the traces in [Jaeger UI](http://localhost:16686/).
The traces should be visible under the service name "trt-server".
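If your client is itself instrumented, it can propagate its trace context by attaching the standard W3C `traceparent` (and optionally `tracestate`) headers; the server extracts them so the `llm_request` span joins the caller's trace. A minimal example request is shown below (the port, model name, and trace IDs are illustrative assumptions based on the serve command above):
```bash
curl http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -H "traceparent: 00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01" \
  -d '{
        "model": "models/Qwen3-8B/",
        "messages": [{"role": "user", "content": "Hello"}]
      }'
```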
## Configuration for Disaggregated Serving
For disaggregated serving scenarios, the configuration for the ctx and gen servers remains the same as in standalone mode. For the proxy, you can configure it as follows:
```yaml
# disagg_config.yaml
hostname: 127.0.0.1
port: 8000
backend: pytorch
context_servers:
num_instances: 1
urls:
- "127.0.0.1:8001"
generation_servers:
num_instances: 1
urls:
- "127.0.0.1:8002"
otlp_config:
otlp_traces_endpoint: "grpc://0.0.0.0:4317"
```
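With the `otlp_config` section in place, the ctx and gen servers keep their own `otlp_traces_endpoint` settings, and the proxy can be launched against this file as usual (a sketch assuming the standard `trtllm-serve disaggregated` entry point):
```bash
trtllm-serve disaggregated -c disagg_config.yaml
```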

View File

@ -28,13 +28,14 @@ from contextlib import contextmanager
from enum import EnumMeta
from functools import lru_cache, partial, wraps
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Union
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
import numpy as np
import nvtx
from mpi4py import MPI
from mpi4py.util import pkl5
from packaging import version
from typing_extensions import ParamSpec
# isort: off
import torch
@ -1155,6 +1156,21 @@ def set_prometheus_multiproc_dir() -> object:
f"PROMETHEUS_MULTIPROC_DIR: {os.environ['PROMETHEUS_MULTIPROC_DIR']}")
P = ParamSpec("P")
# From: https://stackoverflow.com/a/4104188/2749989
def run_once(f: Callable[P, None]) -> Callable[P, None]:
def wrapper(*args: P.args, **kwargs: P.kwargs) -> None:
if not wrapper.has_run: # type: ignore[attr-defined]
wrapper.has_run = True # type: ignore[attr-defined]
return f(*args, **kwargs)
wrapper.has_run = False # type: ignore[attr-defined]
return wrapper
TORCH_PYBIND11_ABI = None

View File

@ -91,6 +91,7 @@ def get_llm_args(model: str,
trust_remote_code: bool = False,
reasoning_parser: Optional[str] = None,
fail_fast_on_attention_window_too_large: bool = False,
otlp_traces_endpoint: Optional[str] = None,
enable_chunked_prefill: bool = False,
**llm_args_extra_dict: Any):
@ -134,6 +135,7 @@ def get_llm_args(model: str,
"reasoning_parser": reasoning_parser,
"fail_fast_on_attention_window_too_large":
fail_fast_on_attention_window_too_large,
"otlp_traces_endpoint": otlp_traces_endpoint,
"enable_chunked_prefill": enable_chunked_prefill,
}
@ -322,6 +324,10 @@ class ChoiceWithAlias(click.Choice):
help=
"Exit with runtime error when attention window is too large to fit even a single sequence in the KV cache."
)
@click.option("--otlp_traces_endpoint",
type=str,
default=None,
help="Target URL to which OpenTelemetry traces will be sent.")
@click.option("--disagg_cluster_uri",
type=str,
default=None,
@ -344,8 +350,8 @@ def serve(
extra_llm_api_options: Optional[str], reasoning_parser: Optional[str],
metadata_server_config_file: Optional[str], server_role: Optional[str],
fail_fast_on_attention_window_too_large: bool,
enable_chunked_prefill: bool, disagg_cluster_uri: Optional[str],
media_io_kwargs: Optional[str]):
otlp_traces_endpoint: Optional[str], enable_chunked_prefill: bool,
disagg_cluster_uri: Optional[str], media_io_kwargs: Optional[str]):
"""Running an OpenAI API compatible server
MODEL: model name | HF checkpoint path | TensorRT engine path
@ -371,6 +377,7 @@ def serve(
reasoning_parser=reasoning_parser,
fail_fast_on_attention_window_too_large=
fail_fast_on_attention_window_too_large,
otlp_traces_endpoint=otlp_traces_endpoint,
enable_chunked_prefill=enable_chunked_prefill)
llm_args_extra_dict = {}

View File

@ -886,7 +886,15 @@ def _get_metrics_dict(
req_perf_metrics.timing_metrics.first_scheduled_time.
total_seconds(),
RequestEventTiming.LAST_TOKEN_TIME:
req_perf_metrics.timing_metrics.last_token_time.total_seconds()
req_perf_metrics.timing_metrics.last_token_time.total_seconds(),
RequestEventTiming.KV_CACHE_TRANSFER_START:
req_perf_metrics.timing_metrics.kv_cache_transfer_start.
total_seconds(),
RequestEventTiming.KV_CACHE_TRANSFER_END:
req_perf_metrics.timing_metrics.kv_cache_transfer_end.
total_seconds(),
RequestEventTiming.KV_CACHE_SIZE:
req_perf_metrics.timing_metrics.kv_cache_size,
}
return metrics_dict

View File

@ -5,6 +5,7 @@ import platform
import signal
import traceback
from abc import ABC, abstractmethod
from collections.abc import Mapping
from pathlib import Path
from queue import Queue
from typing import (TYPE_CHECKING, AsyncIterable, Dict, Generator, List,
@ -123,6 +124,7 @@ class GenerationExecutor(ABC):
streaming: bool = False,
kv_cache_retention_config: Optional[KvCacheRetentionConfig] = None,
disaggregated_params: Optional[DisaggregatedParams] = None,
trace_headers: Optional[Mapping[str, str]] = None,
postproc_params: Optional[PostprocParams] = None,
multimodal_params: Optional[MultimodalParams] = None,
scheduling_params: Optional[SchedulingParams] = None,
@ -150,6 +152,7 @@ class GenerationExecutor(ABC):
streaming=streaming,
kv_cache_retention_config=kv_cache_retention_config,
disaggregated_params=disaggregated_params,
trace_headers=trace_headers,
multimodal_params=multimodal_params,
scheduling_params=scheduling_params,
cache_salt_id=cache_salt_id,

View File

@ -1,4 +1,5 @@
import os
from collections.abc import Mapping
from dataclasses import dataclass
from typing import List, Optional, Union
@ -94,6 +95,7 @@ class GenerationRequest:
streaming: bool = False,
kv_cache_retention_config: Optional[KvCacheRetentionConfig] = None,
disaggregated_params: Optional[DisaggregatedParams] = None,
trace_headers: Optional[Mapping[str, str]] = None,
postproc_params: Optional[PostprocParams] = None,
multimodal_params: Optional[MultimodalParams] = None,
scheduling_params: Optional[SchedulingParams] = None,
@ -123,6 +125,7 @@ class GenerationRequest:
self.kv_cache_retention_config = kv_cache_retention_config
self.id: Optional[int] = None
self.disaggregated_params = disaggregated_params
self.trace_headers = trace_headers
self.scheduling_params = scheduling_params
self.cache_salt_id = cache_salt_id
self.arrival_time = arrival_time

View File

@ -1,6 +1,7 @@
import asyncio
import json
import threading
import time
import weakref
from dataclasses import dataclass, field
from queue import Empty, Queue
@ -11,6 +12,8 @@ from weakref import WeakMethod
import torch
import torch.nn.functional as F
from tensorrt_llm.llmapi import tracing
try:
import ray
except ModuleNotFoundError:
@ -268,6 +271,7 @@ class GenerationResultBase:
self.avg_decoded_tokens_per_iter: Optional[float] = None
self._done = False
self.metrics_dict = {}
self.trace_headers: Optional[dict[str, str]] = None
if ray_queue is not None:
if has_event_loop():
@ -436,6 +440,7 @@ class GenerationResultBase:
raise ValueError(
f"Unknown finish reason: {finish_reasons[src_idx]}")
self.record_stats(output, req_perf_metrics_dict)
self.do_tracing(output, req_perf_metrics_dict)
@print_traceback_on_error
@nvtx_range_debug("handle_response",
@ -472,7 +477,7 @@ class GenerationResultBase:
self._outputs[0].disaggregated_params = disaggregated_params
if response.metrics:
self.metrics_dict = response.metrics
self.metrics_dict.update(response.metrics)
if response.error:
if self._background_error_handler is not None and (
@ -570,7 +575,110 @@ class GenerationResultBase:
stats, len(output.token_ids), self.sampling_params.n > 1)
if processed_metrics_stat:
metrics_stats.update(processed_metrics_stat)
self.metrics_dict = metrics_stats
self.metrics_dict.update(metrics_stats)
def do_tracing(
self,
output: CompletionOutput,
req_perf_metrics_dict: Optional[dict[str, float]] = None,
) -> None:
"""Perform distributed tracing for the generation request.
Args:
output (CompletionOutput): The output of the generation result.
req_perf_metrics_dict (Optional[dict[str, float]]): Request performance metrics. Defaults to None.
"""
if not tracing.global_otlp_tracer():
return
metrics_dict = self.metrics_dict
if not metrics_dict or not req_perf_metrics_dict:
# Insufficient request metrics available; trace generation aborted.
tracing.insufficient_request_metrics_warning()
return
trace_context = tracing.extract_trace_context(self.trace_headers)
sampling_params = self.sampling_params
# Since arrival_time and other timing metrics are based on different time origins,
# we need to apply corrections to align them with absolute timestamps
time_correction = time.time() - time.monotonic()
arrival_time = req_perf_metrics_dict.get(
RequestEventTiming.ARRIVAL_TIME, 0)
with tracing.global_otlp_tracer().start_as_current_span(
"llm_request",
kind=tracing.SpanKind.SERVER,
context=trace_context,
start_time=int((arrival_time + time_correction) * 1e9),
) as span:
def safe_set_attr(span, attr, value):
if value is not None:
span.set_attribute(attr, value)
safe_set_attr(span,
tracing.SpanAttributes.GEN_AI_REQUEST_TEMPERATURE,
sampling_params.temperature)
safe_set_attr(span, tracing.SpanAttributes.GEN_AI_REQUEST_TOP_P,
sampling_params.top_p)
safe_set_attr(span, tracing.SpanAttributes.GEN_AI_REQUEST_TOP_K,
sampling_params.top_k)
safe_set_attr(
span,
tracing.SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS,
sampling_params.max_tokens,
)
safe_set_attr(span, tracing.SpanAttributes.GEN_AI_REQUEST_N,
sampling_params.n)
safe_set_attr(span, tracing.SpanAttributes.GEN_AI_REQUEST_ID,
self.id)
if prompt_token_ids := getattr(self, "prompt_token_ids", None):
safe_set_attr(span,
tracing.SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS,
len(prompt_token_ids))
safe_set_attr(span,
tracing.SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS,
output.length)
safe_set_attr(
span, tracing.SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN,
metrics_dict.get(MetricNames.TTFT, -1))
safe_set_attr(span, tracing.SpanAttributes.GEN_AI_LATENCY_E2E,
metrics_dict.get(MetricNames.E2E, -1))
safe_set_attr(span,
tracing.SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE,
metrics_dict.get(MetricNames.REQUEST_QUEUE_TIME, -1))
safe_set_attr(
span, tracing.SpanAttributes.GEN_AI_RESPONSE_FINISH_REASONS,
json.dumps([output.finish_reason])
if output.finish_reason else None)
safe_set_attr(
span,
tracing.SpanAttributes.GEN_AI_LATENCY_KV_CACHE_TRANSFER_TIME,
req_perf_metrics_dict.get(
RequestEventTiming.KV_CACHE_TRANSFER_END, 0.0) -
req_perf_metrics_dict.get(
RequestEventTiming.KV_CACHE_TRANSFER_START, 0.0))
if req_perf_metrics_dict.get(
RequestEventTiming.KV_CACHE_TRANSFER_START,
0) and req_perf_metrics_dict.get(
RequestEventTiming.KV_CACHE_TRANSFER_END, 0):
tracing.add_event(
tracing.SpanEvents.KV_CACHE_TRANSFER_START,
timestamp=int((req_perf_metrics_dict.get(
RequestEventTiming.KV_CACHE_TRANSFER_START, 0.0) +
time_correction) * 1e9))
tracing.add_event(
tracing.SpanEvents.KV_CACHE_TRANSFER_END,
attributes={
"kv_cache_size":
req_perf_metrics_dict.get(
RequestEventTiming.KV_CACHE_SIZE, 0)
},
timestamp=int((req_perf_metrics_dict.get(
RequestEventTiming.KV_CACHE_TRANSFER_END, 0.0) +
time_correction) * 1e9))
class DetokenizedGenerationResultBase(GenerationResultBase):
@ -688,6 +796,7 @@ class GenerationResult(GenerationResultBase):
self.disaggregated_params = disaggregated_params
# minimal sampling params needed for logprob calculation
self._logprob_params = logprob_params
self.trace_headers = generation_request.trace_headers
# for aborting the request
self._executor: Optional[weakref.ReferenceType[

View File

@ -43,6 +43,12 @@ class ConditionalDisaggConfig():
max_local_prefill_length: int = 0
@dataclass
class OtlpConfig():
otlp_traces_endpoint: Optional[
str] = None # Target URL to which OpenTelemetry traces will be sent
@dataclass
class MinimalInstances:
context_servers: int = 1 # the minimal number of context servers
@ -66,6 +72,7 @@ class DisaggServerConfig():
ctx_router_config: Optional[RouterConfig] = None
gen_router_config: Optional[RouterConfig] = None
conditional_disagg_config: Optional[ConditionalDisaggConfig] = None
otlp_config: Optional[OtlpConfig] = None
max_retries: int = 1
perf_metrics_max_requests: int = 0
disagg_cluster_config: Optional[DisaggClusterConfig] = None
@ -112,6 +119,7 @@ def extract_disagg_cfg(hostname: str = 'localhost',
context_servers: Optional[dict] = None,
generation_servers: Optional[dict] = None,
conditional_disagg_config: Optional[dict] = None,
otlp_config: Optional[dict] = None,
disagg_cluster: Optional[dict] = None,
**kwargs: Any) -> DisaggServerConfig:
context_servers = context_servers or {}
@ -149,10 +157,12 @@ def extract_disagg_cfg(hostname: str = 'localhost',
conditional_disagg_config = ConditionalDisaggConfig(
**conditional_disagg_config) if conditional_disagg_config else None
otlp_config = OtlpConfig(**otlp_config) if otlp_config else None
config = DisaggServerConfig(server_configs, hostname, port,
ctx_router_config, gen_router_config,
conditional_disagg_config, max_retries,
perf_metrics_max_requests,
conditional_disagg_config, otlp_config,
max_retries, perf_metrics_max_requests,
disagg_cluster_config)
return config

View File

@ -6,6 +6,7 @@ import socket
import tempfile
import time
import weakref
from collections.abc import Mapping
from pathlib import Path
from typing import Any, List, Literal, Optional, Sequence, Union
@ -17,6 +18,8 @@ from tensorrt_llm._utils import mpi_disabled
from tensorrt_llm.inputs.data import TextPrompt
from tensorrt_llm.inputs.multimodal import MultimodalInput, MultimodalParams
from tensorrt_llm.inputs.registry import DefaultInputProcessor
from tensorrt_llm.llmapi import tracing
from tensorrt_llm.metrics.enums import MetricNames
from .._utils import nvtx_range_debug
from ..bindings import executor as tllm
@ -230,6 +233,15 @@ class BaseLLM:
self.mpi_session.shutdown()
raise
try:
if self.args.otlp_traces_endpoint:
tracing.init_tracer("trt.llm", self.args.otlp_traces_endpoint)
logger.info(
f"Initialized OTLP tracer successfully, endpoint: {self.args.otlp_traces_endpoint}"
)
except Exception as e:
logger.error(f"Failed to initialize OTLP tracer: {e}")
exception_handler.register(self, 'shutdown')
atexit.register(LLM._shutdown_wrapper, weakref.ref(self))
@ -338,6 +350,7 @@ class BaseLLM:
streaming: bool = False,
kv_cache_retention_config: Optional[KvCacheRetentionConfig] = None,
disaggregated_params: Optional[DisaggregatedParams] = None,
trace_headers: Optional[Mapping[str, str]] = None,
_postproc_params: Optional[PostprocParams] = None,
scheduling_params: Optional[SchedulingParams] = None,
cache_salt: Optional[str] = None,
@ -354,6 +367,7 @@ class BaseLLM:
streaming (bool): Whether to use the streaming mode for the generation. Defaults to False.
kv_cache_retention_config (tensorrt_llm.bindings.executor.KvCacheRetentionConfig, optional): Configuration for the request's retention in the KV Cache. Defaults to None.
disaggregated_params (tensorrt_llm.disaggregated_params.DisaggregatedParams, optional): Disaggregated parameters. Defaults to None.
trace_headers (Mapping[str, str], optional): Trace headers. Defaults to None.
scheduling_params (tensorrt_llm.scheduling_params.SchedulingParams, optional): Scheduling parameters. Defaults to None.
cache_salt (str, optional): If specified, KV cache will be salted with the provided string to limit the kv cache reuse to the requests with the same string. Defaults to None.
Returns:
@ -486,6 +500,7 @@ class BaseLLM:
streaming=streaming,
kv_cache_retention_config=kv_cache_retention_config,
disaggregated_params=disaggregated_params,
trace_headers=trace_headers,
postproc_params=_postproc_params,
multimodal_params=multimodal_params,
scheduling_params=scheduling_params,
@ -493,6 +508,10 @@ class BaseLLM:
arrival_time=arrival_time,
)
if sampling_params.return_perf_metrics:
result.metrics_dict.update(
{MetricNames.ARRIVAL_TIMESTAMP: time.time()})
return RequestOutput._from_generation_result(result, prompt,
self.tokenizer)

View File

@ -1703,6 +1703,12 @@ class BaseLlmArgs(StrictBaseModel):
exclude=True,
alias="_mpi_session")
otlp_traces_endpoint: Optional[str] = Field(
default=None,
description="Target URL to which OpenTelemetry traces will be sent.",
alias="otlp_traces_endpoint",
status="prototype")
backend: Optional[str] = Field(
default=None,
description="The backend to use for this LLM instance.",

View File

@ -0,0 +1,227 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
__all__ = [
'SpanAttributes', 'SpanKind', 'contains_trace_headers',
'extract_trace_context', 'get_span_exporter', 'global_otlp_tracer',
'init_tracer', 'insufficient_request_metrics_warning', 'is_otel_available',
'is_tracing_enabled', 'log_tracing_disabled_warning',
'set_global_otlp_tracer', 'extract_trace_headers'
]
import functools
import os
import typing
from collections.abc import Mapping
from typing import Optional
from strenum import StrEnum
from tensorrt_llm._utils import run_once
from tensorrt_llm.logger import logger
# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0rc1/vllm/tracing.py#L11
TRACE_HEADERS = ["traceparent", "tracestate"]
_global_tracer_ = None
_is_otel_imported = False
otel_import_error_traceback: Optional[str] = None
try:
from opentelemetry.context.context import Context
from opentelemetry.sdk.environment_variables import \
OTEL_EXPORTER_OTLP_TRACES_PROTOCOL
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.trace import (SpanKind, Status, StatusCode, Tracer,
get_current_span, set_tracer_provider)
from opentelemetry.trace.propagation.tracecontext import \
TraceContextTextMapPropagator
_is_otel_imported = True
except ImportError:
import traceback
otel_import_error_traceback = traceback.format_exc()
class Context: # type: ignore
pass
class BaseSpanAttributes: # type: ignore
pass
class SpanKind: # type: ignore
pass
class Tracer: # type: ignore
pass
def is_otel_available() -> bool:
return _is_otel_imported
def init_tracer(instrumenting_module_name: str,
otlp_traces_endpoint: str) -> Optional[Tracer]:
if not is_otel_available():
raise ValueError(
"OpenTelemetry is not available. Unable to initialize "
"a tracer. Ensure OpenTelemetry packages are installed. "
f"Original error:\n{otel_import_error_traceback}")
trace_provider = TracerProvider()
span_exporter = get_span_exporter(otlp_traces_endpoint)
trace_provider.add_span_processor(BatchSpanProcessor(span_exporter))
set_tracer_provider(trace_provider)
tracer = trace_provider.get_tracer(instrumenting_module_name)
set_global_otlp_tracer(tracer)
return tracer
def get_span_exporter(endpoint):
protocol = os.environ.get(OTEL_EXPORTER_OTLP_TRACES_PROTOCOL, "grpc")
if protocol == "grpc":
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import \
OTLPSpanExporter
elif protocol == "http/protobuf":
from opentelemetry.exporter.otlp.proto.http.trace_exporter import \
OTLPSpanExporter # type: ignore
else:
raise ValueError(
f"Unsupported OTLP protocol '{protocol}' is configured")
return OTLPSpanExporter(endpoint=endpoint)
def extract_trace_context(
headers: Optional[Mapping[str, str]]) -> Optional[Context]:
if is_otel_available():
headers = headers or {}
return TraceContextTextMapPropagator().extract(headers)
else:
return None
def extract_trace_headers(
headers: Mapping[str, str]) -> Optional[Mapping[str, str]]:
if is_tracing_enabled():
# Return only recognized trace headers with normalized lowercase keys
lower_map = {k.lower(): v for k, v in headers.items()}
return {h: lower_map[h] for h in TRACE_HEADERS if h in lower_map}
if contains_trace_headers(headers):
log_tracing_disabled_warning()
return None
def inject_trace_headers(headers: Mapping[str, str]) -> Mapping[str, str]:
if is_tracing_enabled():
trace_headers = extract_trace_headers(headers) if not headers else {}
TraceContextTextMapPropagator().inject(trace_headers)
return trace_headers
return None
def global_otlp_tracer() -> Tracer:
"""Get the global OTLP instance in the current process."""
return _global_tracer_
def set_global_otlp_tracer(tracer: Tracer):
"""Set the global OTLP Tracer instance in the current process."""
global _global_tracer_
assert _global_tracer_ is None
_global_tracer_ = tracer
def is_tracing_enabled() -> bool:
return _global_tracer_ is not None
class SpanAttributes(StrEnum):
"""Span attributes for LLM tracing following GenAI semantic conventions."""
# Token usage attributes
GEN_AI_USAGE_COMPLETION_TOKENS = "gen_ai.usage.completion_tokens"
GEN_AI_USAGE_PROMPT_TOKENS = "gen_ai.usage.prompt_tokens"
# Request attributes
GEN_AI_REQUEST_MAX_TOKENS = "gen_ai.request.max_tokens"
GEN_AI_REQUEST_TOP_P = "gen_ai.request.top_p"
GEN_AI_REQUEST_TOP_K = "gen_ai.request.top_k"
GEN_AI_REQUEST_TEMPERATURE = "gen_ai.request.temperature"
GEN_AI_REQUEST_ID = "gen_ai.request.id"
GEN_AI_REQUEST_N = "gen_ai.request.n"
# Latency attributes
GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN = "gen_ai.latency.time_to_first_token" # nosec B105
GEN_AI_LATENCY_E2E = "gen_ai.latency.e2e"
GEN_AI_LATENCY_TIME_IN_QUEUE = "gen_ai.latency.time_in_queue"
GEN_AI_LATENCY_KV_CACHE_TRANSFER_TIME = "gen_ai.latency.kv_cache_transfer_time"
# Response attributes
GEN_AI_RESPONSE_FINISH_REASONS = "gen_ai.response.finish_reasons"
class SpanEvents(StrEnum):
"""Span events for LLM tracing."""
KV_CACHE_TRANSFER_START = "kv_cache_transfer_start"
KV_CACHE_TRANSFER_END = "kv_cache_transfer_end"
CTX_SERVER_SELECTED = "ctx_server.selected"
GEN_SERVER_SELECTED = "gen_server.selected"
def contains_trace_headers(headers: Mapping[str, str]) -> bool:
lower_keys = {k.lower() for k in headers.keys()}
return any(h in lower_keys for h in TRACE_HEADERS)
def add_event(name: str,
attributes: Optional[Mapping[str, object]] = None,
timestamp: typing.Optional[int] = None) -> None:
"""Add an event to the current span if tracing is available."""
if not is_tracing_enabled():
return
get_current_span().add_event(name, attributes, timestamp)
@run_once
def log_tracing_disabled_warning() -> None:
logger.warning(
"Received a request with trace context but tracing is disabled")
@run_once
def insufficient_request_metrics_warning() -> None:
logger.warning(
"Insufficient request metrics available; trace generation aborted.")
def trace_span(name: Optional[str] = None):
def decorator(func):
@functools.wraps(func)
async def async_wrapper(*args, **kwargs):
span_name = name if name is not None else func.__name__
if global_otlp_tracer() is None:
return await func(*args, **kwargs)
trace_headers = None
for arg in list(args) + list(kwargs.values()):
if hasattr(arg, 'headers'):
trace_headers = extract_trace_context(arg.headers)
break
with global_otlp_tracer().start_as_current_span(
span_name, kind=SpanKind.SERVER,
context=trace_headers) as span:
try:
result = await func(*args, **kwargs)
span.set_status(Status(StatusCode.OK))
return result
except Exception as e:
span.record_exception(e)
span.set_status(
Status(StatusCode.ERROR, f"An error occurred: {e}"))
raise e
return async_wrapper
return decorator
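Taken together, the helpers above are intended to be used roughly as in the following sketch. This is a hypothetical usage example rather than part of the commit; the real wiring lives in `BaseLLM.__init__` and the OpenAI/disaggregated servers in the other hunks, and the handler name and endpoint here are assumptions.
```python
from tensorrt_llm.llmapi import tracing

# Create and register the global tracer once per process (mirrors BaseLLM.__init__).
tracing.init_tracer("trt.llm", "grpc://localhost:4317")

# Wrap an async request handler so each call becomes an "llm_request" SERVER span.
# The decorator looks for an argument exposing `.headers` (e.g. a FastAPI Request)
# and uses any W3C trace headers it carries as the parent context.
@tracing.trace_span("llm_request")
async def handle_request(raw_request):
    # Events are attached to the current span only when tracing is enabled.
    tracing.add_event(tracing.SpanEvents.GEN_SERVER_SELECTED,
                      attributes={"server": "127.0.0.1:8002"})
    ...
```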

View File

@ -6,6 +6,7 @@ class MetricNames(Enum):
TPOT = "tpot"
E2E = "e2e"
REQUEST_QUEUE_TIME = "request_queue_time"
ARRIVAL_TIMESTAMP = 'arrival_timestamp'
class RequestEventTiming(Enum):
@ -13,3 +14,6 @@ class RequestEventTiming(Enum):
FIRST_TOKEN_TIME = "first_token_time" # nosec: B105
FIRST_SCHEDULED_TIME = "first_scheduled_time"
LAST_TOKEN_TIME = "last_token_time" # nosec: B105
KV_CACHE_TRANSFER_START = "kv_cache_transfer_start"
KV_CACHE_TRANSFER_END = "kv_cache_transfer_end"
KV_CACHE_SIZE = "kv_cache_size"

View File

@ -6,6 +6,7 @@ import os
import signal
import traceback
from collections import deque
from collections.abc import Mapping
from contextlib import asynccontextmanager
from http import HTTPStatus
from typing import Callable, Optional, Type, Union
@ -19,6 +20,7 @@ from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR
# yapf: disable
from tensorrt_llm.executor import CppExecutorError
from tensorrt_llm.llmapi import tracing
from tensorrt_llm.llmapi.disagg_utils import (DisaggServerConfig,
MetadataServerConfig, ServerRole,
get_ctx_gen_server_urls)
@ -57,6 +59,17 @@ class OpenAIDisaggServer:
self.gen_router = create_router(
config.gen_router_config, self.gen_servers, metadata_server_cfg, self.metadata_server)
self.conditional_disagg_config = config.conditional_disagg_config
self.otlp_cfg = config.otlp_config
try:
if self.otlp_cfg and self.otlp_cfg.otlp_traces_endpoint:
tracing.init_tracer("trt.llm", self.otlp_cfg.otlp_traces_endpoint)
logger.info(
f"Initialized OTLP tracer successfully, endpoint: {self.otlp_cfg.otlp_traces_endpoint}"
)
except Exception as e:
logger.error(f"Failed to initialize OTLP tracer: {e}")
self.perf_metrics_max_requests = config.perf_metrics_max_requests
if self.perf_metrics_max_requests > 0:
# record corresponding keys of context and generation servers for perf metrics
@ -284,7 +297,7 @@ class OpenAIDisaggServer:
return JSONResponse(content=return_metrics)
@tracing.trace_span("llm_request")
async def openai_completion(self, req: CompletionRequest, raw_request: Request) -> Response:
if not await self.is_ready():
raise HTTPException(status_code=400, detail="Cluster is not ready")
@ -301,6 +314,7 @@ class OpenAIDisaggServer:
except Exception as e:
await self._handle_exception(e)
@tracing.trace_span("llm_request")
async def openai_chat_completion(self, req: ChatCompletionRequest, raw_request: Request) -> Response:
if not await self.is_ready():
raise HTTPException(status_code=400, detail="Cluster is not ready")
@ -319,7 +333,8 @@ class OpenAIDisaggServer:
logger.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=f"Internal server error {str(exception)}")
async def _send_context_request(self, ctx_server: str, ctx_req: Union[CompletionRequest, ChatCompletionRequest]):
async def _send_context_request(self, ctx_server: str, ctx_req: Union[CompletionRequest, ChatCompletionRequest],
trace_headers: Optional[Mapping[str, str]] = None):
ctx_req.disaggregated_params = DisaggregatedParams(request_type="context_only")
ctx_req.stream = False
@ -329,10 +344,10 @@ class OpenAIDisaggServer:
await self._increment_metric("ctx_total_requests")
try:
if isinstance(ctx_req, ChatCompletionRequest):
ctx_response = await self.send_chat_request(ctx_server, ctx_req)
ctx_response = await self.send_chat_request(ctx_server, ctx_req, trace_headers)
else:
assert isinstance(ctx_req, CompletionRequest)
ctx_response = await self.send_completion_request(ctx_server, ctx_req)
ctx_response = await self.send_completion_request(ctx_server, ctx_req, trace_headers)
finally:
await self.ctx_router.finish_request(ctx_req)
await self._increment_metric("ctx_completed_requests")
@ -352,9 +367,11 @@ class OpenAIDisaggServer:
gen_server = None
ctx_request_id = None
need_ctx = False
trace_headers = tracing.inject_trace_headers(raw_request.headers)
async def _merge_streaming_responses(ctx_response,
gen_req: Union[CompletionRequest, ChatCompletionRequest]):
gen_req: Union[CompletionRequest, ChatCompletionRequest],
trace_headers: Optional[Mapping[str, str]] = None):
try:
if ctx_response is not None and len(ctx_response.choices) != 1:
raise ValueError("Context server did not return a single choice. This is not expected")
@ -366,9 +383,9 @@ class OpenAIDisaggServer:
# Then yield the generation responses
await self._increment_metric("gen_total_requests")
if isinstance(gen_req, CompletionRequest):
gen_response = await self.send_completion_request(gen_server, gen_req)
gen_response = await self.send_completion_request(gen_server, gen_req, trace_headers)
elif isinstance(gen_req, ChatCompletionRequest):
gen_response = await self.send_chat_request(gen_server, gen_req)
gen_response = await self.send_chat_request(gen_server, gen_req, trace_headers)
else:
raise TypeError(f"Invalid request type: {type(gen_req).__name__}")
@ -413,8 +430,11 @@ class OpenAIDisaggServer:
if need_ctx:
ctx_req = copy.deepcopy(req)
ctx_server, _ = await self.ctx_router.get_next_server(ctx_req)
tracing.add_event(tracing.SpanEvents.CTX_SERVER_SELECTED, attributes={"server": str(ctx_server),})
# TODO: add ctx_server info into generation request for pre-registration
ctx_response = await self._send_context_request(ctx_server, ctx_req)
ctx_response = await self._send_context_request(ctx_server, ctx_req, trace_headers)
if ctx_response is not None and len(ctx_response.choices) != 1:
raise ValueError("Context server did not return a single choice. This is not expected")
@ -438,6 +458,7 @@ class OpenAIDisaggServer:
if gen_server is None:
gen_server, _ = await self.gen_router.get_next_server(req)
logger.debug("Sending request to gen server: %s", gen_server)
tracing.add_event(tracing.SpanEvents.GEN_SERVER_SELECTED, attributes={"server": str(gen_server),})
if not req.stream:
try:
@ -448,10 +469,10 @@ class OpenAIDisaggServer:
else:
await self._increment_metric("gen_total_requests")
if isinstance(req, CompletionRequest):
gen_response = await self.send_completion_request(gen_server, req)
gen_response = await self.send_completion_request(gen_server, req, trace_headers)
else:
assert isinstance(req, ChatCompletionRequest)
gen_response = await self.send_chat_request(gen_server, req)
gen_response = await self.send_chat_request(gen_server, req, trace_headers)
await self._increment_metric("gen_completed_requests")
if need_ctx and self.perf_metrics_keys is not None:
raw_request.state.server_first_token_time = get_steady_clock_now_in_seconds()
@ -465,7 +486,7 @@ class OpenAIDisaggServer:
else:
# Return a streaming response that combines both context and generation responses
return StreamingResponse(
_merge_streaming_responses(ctx_response, req),
_merge_streaming_responses(ctx_response, req, trace_headers),
media_type="text/event-stream"
)
except:
@ -482,8 +503,15 @@ class OpenAIDisaggServer:
timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
await uvicorn.Server(config).serve()
async def create_generator(self, url: str, request: Union[CompletionRequest, ChatCompletionRequest], end_point: str):
async with self.session.post(url + end_point, json=request.model_dump(exclude_unset=True)) as response:
async def create_generator(self, url: str, request: Union[CompletionRequest, ChatCompletionRequest],
end_point: str, trace_headers: Optional[Mapping[str, str]] = None):
# Prepare headers
headers = {"Content-Type": "application/json"}
if trace_headers:
headers.update(trace_headers)
async with self.session.post(url + end_point, json=request.model_dump(exclude_unset=True),
headers=headers) as response:
content_type = response.headers.get("Content-Type", "")
if "text/event-stream" in content_type:
if not request.stream:
@ -498,26 +526,33 @@ class OpenAIDisaggServer:
logger.error(f"Unexpected error in stream: {e}")
raise
async def create_completion_generator(self, url: str, request: CompletionRequest):
async for chunk in self.create_generator(url, request, "/v1/completions"):
async def create_completion_generator(self, url: str, request: CompletionRequest,
trace_headers: Optional[Mapping[str, str]] = None):
async for chunk in self.create_generator(url, request, "/v1/completions", trace_headers):
yield chunk
async def create_chat_generator(self, url: str, request: ChatCompletionRequest):
async for chunk in self.create_generator(url, request, "/v1/chat/completions"):
async def create_chat_generator(self, url: str, request: ChatCompletionRequest,
trace_headers: Optional[Mapping[str, str]] = None):
async for chunk in self.create_generator(url, request, "/v1/chat/completions", trace_headers):
yield chunk
async def send_request(self, url: str,
request: Union[CompletionRequest, ChatCompletionRequest],
endpoint: str,
response_type: Type[Union[CompletionResponse, ChatCompletionResponse]],
create_generator: Callable) -> Union[CompletionResponse, ChatCompletionResponse, StreamingResponse]:
create_generator: Callable,
trace_headers: Optional[Mapping[str, str]] = None) -> Union[CompletionResponse, ChatCompletionResponse, StreamingResponse]:
for attempt in range(self.max_retries + 1):
try:
headers = {"Content-Type": "application/json"}
if trace_headers:
headers.update(trace_headers)
if request.stream:
response_generator = create_generator(url, request)
response_generator = create_generator(url, request, headers)
return StreamingResponse(content=response_generator, media_type="text/event-stream")
else:
async with self.session.post(url + endpoint, json=request.model_dump(exclude_unset=True)) as response:
async with self.session.post(url + endpoint, json=request.model_dump(exclude_unset=True),
headers=headers) as response:
content_type = response.headers.get("Content-Type", "")
if "text/event-stream" in content_type:
raise ValueError("Received an event-stream although request stream was False")
@ -537,12 +572,13 @@ class OpenAIDisaggServer:
logger.error(f"Error encountered while processing request to {url+endpoint}: {e}")
raise
async def send_completion_request(self, url: str, request: CompletionRequest,
trace_headers: Optional[Mapping[str, str]] = None) -> Union[CompletionResponse, StreamingResponse]:
return await self.send_request(url, request, "/v1/completions", CompletionResponse, self.create_completion_generator, trace_headers)
async def send_completion_request(self, url: str, request: CompletionRequest) -> Union[CompletionResponse, StreamingResponse]:
return await self.send_request(url, request, "/v1/completions", CompletionResponse, self.create_completion_generator)
async def send_chat_request(self, url: str, request: ChatCompletionRequest) -> ChatCompletionResponse:
return await self.send_request(url, request, "/v1/chat/completions", ChatCompletionResponse, self.create_chat_generator)
async def send_chat_request(self, url: str, request: ChatCompletionRequest,
trace_headers: Optional[Mapping[str, str]] = None) -> ChatCompletionResponse:
return await self.send_request(url, request, "/v1/chat/completions", ChatCompletionResponse, self.create_chat_generator, trace_headers)
async def set_steady_clock_offsets(self, session: aiohttp.ClientSession):
STEADY_CLOCK_OFFSET_ENDPOINT = "/steady_clock_offset"

View File

@ -28,7 +28,7 @@ from tensorrt_llm.inputs.data import TokensPrompt
from tensorrt_llm.inputs.multimodal import MultimodalServerConfig
from tensorrt_llm.inputs.utils import ConversationMessage, apply_chat_template
from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams
from tensorrt_llm.llmapi import MultimodalEncoder
from tensorrt_llm.llmapi import MultimodalEncoder, tracing
from tensorrt_llm.llmapi.disagg_utils import (DisaggClusterConfig,
MetadataServerConfig, ServerRole)
from tensorrt_llm.llmapi.llm import RequestOutput
@ -541,6 +541,8 @@ class OpenAIServer:
postproc_args=postproc_args,
)
trace_headers = (None if raw_request is None else tracing.extract_trace_headers(raw_request.headers))
promise = self.llm.generate_async(
inputs=prompt,
sampling_params=sampling_params,
@ -549,6 +551,7 @@ class OpenAIServer:
lora_request=request.lora_request,
disaggregated_params=disaggregated_params,
cache_salt=request.cache_salt,
trace_headers=trace_headers,
)
asyncio.create_task(self.await_disconnected(raw_request, promise))
if not self.postproc_worker_enabled:
@ -763,6 +766,7 @@ class OpenAIServer:
if request.stream else completion_response_post_processor,
postproc_args=postproc_args,
)
trace_headers = (None if raw_request is None else tracing.extract_trace_headers(raw_request.headers))
prompt = prompt_inputs(prompt)
if prompt.get("prompt") is not None:
@ -777,7 +781,8 @@ class OpenAIServer:
_postproc_params=postproc_params,
streaming=request.stream,
lora_request=request.lora_request,
disaggregated_params=disaggregated_params
disaggregated_params=disaggregated_params,
trace_headers=trace_headers
)
asyncio.create_task(self.await_disconnected(raw_request, promise))
if not self.postproc_worker_enabled:

View File

@ -3,6 +3,7 @@ import copy
import inspect
import os
import pathlib
from collections.abc import Mapping
from dataclasses import _HAS_DEFAULT_FACTORY_CLASS, dataclass, fields
from pprint import pprint
from types import MethodType, NoneType

View File

@ -187,6 +187,10 @@ methods:
annotation: Optional[tensorrt_llm.llmapi.llm_args.SparseAttentionConfig]
default: null
status: prototype
otlp_traces_endpoint:
annotation: Optional[str]
default: null
status: prototype
ray_worker_extension_cls:
annotation: Optional[str]
default: null
@ -222,6 +226,10 @@ methods:
cache_salt:
annotation: Optional[str]
default: null
trace_headers:
annotation: Optional[Mapping[str, str]]
default: null
status: prototype
return_annotation: tensorrt_llm.llmapi.llm.RequestOutput
get_kv_cache_events:
parameters:

View File

@ -20,4 +20,13 @@ methods:
annotation: Optional[dict[str, float]]
default: None
return_annotation: None
do_tracing:
parameters:
output:
annotation: tensorrt_llm.executor.result.CompletionOutput
default: inspect._empty
req_perf_metrics_dict:
annotation: Optional[dict[str, float]]
default: None
return_annotation: None
properties: {}