# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.

__all__ = [
    'SpanAttributes', 'SpanKind', 'contains_trace_headers',
    'extract_trace_context', 'get_span_exporter', 'global_otlp_tracer',
    'init_tracer', 'insufficient_request_metrics_warning', 'is_otel_available',
    'is_tracing_enabled', 'log_tracing_disabled_warning',
    'set_global_otlp_tracer', 'extract_trace_headers'
]

import functools
import os
import typing
from collections.abc import Mapping
from typing import Optional

from strenum import StrEnum

from tensorrt_llm._utils import run_once
from tensorrt_llm.logger import logger

# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0rc1/vllm/tracing.py#L11
TRACE_HEADERS = ["traceparent", "tracestate"]

_global_tracer_ = None
_is_otel_imported = False
otel_import_error_traceback: Optional[str] = None

try:
    from opentelemetry.context.context import Context
    from opentelemetry.sdk.environment_variables import \
        OTEL_EXPORTER_OTLP_TRACES_PROTOCOL
    from opentelemetry.sdk.trace import TracerProvider
    from opentelemetry.sdk.trace.export import BatchSpanProcessor
    from opentelemetry.trace import (SpanKind, Status, StatusCode, Tracer,
                                     get_current_span, set_tracer_provider)
    from opentelemetry.trace.propagation.tracecontext import \
        TraceContextTextMapPropagator

    _is_otel_imported = True
except ImportError:
    import traceback

    otel_import_error_traceback = traceback.format_exc()

    # Minimal stand-ins so annotations in this module still resolve when the
    # OpenTelemetry packages are not installed.
    class Context:  # type: ignore
        pass

    class BaseSpanAttributes:  # type: ignore
        pass

    class SpanKind:  # type: ignore
        pass

    class Tracer:  # type: ignore
        pass


def is_otel_available() -> bool:
    return _is_otel_imported


def init_tracer(instrumenting_module_name: str,
                otlp_traces_endpoint: str) -> Optional[Tracer]:
    if not is_otel_available():
        raise ValueError(
            "OpenTelemetry is not available. Unable to initialize "
            "a tracer. Ensure OpenTelemetry packages are installed. "
            f"Original error:\n{otel_import_error_traceback}")
    trace_provider = TracerProvider()
    span_exporter = get_span_exporter(otlp_traces_endpoint)
    trace_provider.add_span_processor(BatchSpanProcessor(span_exporter))
    set_tracer_provider(trace_provider)
    tracer = trace_provider.get_tracer(instrumenting_module_name)
    set_global_otlp_tracer(tracer)
    return tracer
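
# A minimal usage sketch (the module name and endpoint below are illustrative,
# not values taken from this repo):
#
#   tracer = init_tracer("tensorrt_llm.serve", "http://localhost:4317")
#   with tracer.start_as_current_span("demo") as span:
#       span.set_attribute(SpanAttributes.GEN_AI_REQUEST_ID, "req-0")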


def get_span_exporter(endpoint):
    protocol = os.environ.get(OTEL_EXPORTER_OTLP_TRACES_PROTOCOL, "grpc")
    if protocol == "grpc":
        from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import \
            OTLPSpanExporter
    elif protocol == "http/protobuf":
        from opentelemetry.exporter.otlp.proto.http.trace_exporter import \
            OTLPSpanExporter  # type: ignore
    else:
        raise ValueError(
            f"Unsupported OTLP protocol '{protocol}' is configured")
    return OTLPSpanExporter(endpoint=endpoint)
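
# Protocol selection follows the standard OTEL_EXPORTER_OTLP_TRACES_PROTOCOL
# environment variable; for example (endpoint and port are illustrative):
#
#   os.environ[OTEL_EXPORTER_OTLP_TRACES_PROTOCOL] = "http/protobuf"
#   exporter = get_span_exporter("http://localhost:4318/v1/traces")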


def extract_trace_context(
        headers: Optional[Mapping[str, str]]) -> Optional[Context]:
    if is_otel_available():
        headers = headers or {}
        return TraceContextTextMapPropagator().extract(headers)
    else:
        return None
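
# For example, given W3C trace-context headers (the header value is the
# example from the W3C spec, shown purely for illustration):
#
#   ctx = extract_trace_context(
#       {"traceparent": "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01"})
#   # `ctx` can then be passed as the parent context when starting a span.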


def extract_trace_headers(
        headers: Mapping[str, str]) -> Optional[Mapping[str, str]]:
    if is_tracing_enabled():
        # Return only recognized trace headers with normalized lowercase keys
        lower_map = {k.lower(): v for k, v in headers.items()}
        return {h: lower_map[h] for h in TRACE_HEADERS if h in lower_map}
    if contains_trace_headers(headers):
        log_tracing_disabled_warning()
    return None
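
# For example, when tracing is enabled (values illustrative and elided):
#
#   extract_trace_headers({"Traceparent": "00-...-01", "x-other": "1"})
#   # -> {"traceparent": "00-...-01"}: keys lowercased, other headers dropped.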


def inject_trace_headers(
        headers: Mapping[str, str]) -> Optional[Mapping[str, str]]:
    if is_tracing_enabled():
        # Start from any trace headers already present on the request, then
        # inject the current trace context into that mapping.
        trace_headers = extract_trace_headers(headers) if headers else {}
        TraceContextTextMapPropagator().inject(trace_headers)
        return trace_headers
    return None
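
# Sketch of propagating the current trace downstream (`request` here is a
# hypothetical object exposing HTTP-style headers):
#
#   outgoing = inject_trace_headers(request.headers) or {}
#   # `outgoing` now carries "traceparent"/"tracestate" for the next hop.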


def global_otlp_tracer() -> Tracer:
    """Get the global OTLP Tracer instance in the current process."""
    return _global_tracer_


def set_global_otlp_tracer(tracer: Tracer):
    """Set the global OTLP Tracer instance in the current process."""
    global _global_tracer_
    assert _global_tracer_ is None, "Global OTLP tracer is already set"
    _global_tracer_ = tracer


def is_tracing_enabled() -> bool:
    return _global_tracer_ is not None


class SpanAttributes(StrEnum):
    """Span attributes for LLM tracing following GenAI semantic conventions."""

    # Token usage attributes
    GEN_AI_USAGE_COMPLETION_TOKENS = "gen_ai.usage.completion_tokens"
    GEN_AI_USAGE_PROMPT_TOKENS = "gen_ai.usage.prompt_tokens"

    # Request attributes
    GEN_AI_REQUEST_MAX_TOKENS = "gen_ai.request.max_tokens"
    GEN_AI_REQUEST_TOP_P = "gen_ai.request.top_p"
    GEN_AI_REQUEST_TOP_K = "gen_ai.request.top_k"
    GEN_AI_REQUEST_TEMPERATURE = "gen_ai.request.temperature"
    GEN_AI_REQUEST_ID = "gen_ai.request.id"
    GEN_AI_REQUEST_N = "gen_ai.request.n"

    # Latency attributes
    GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN = "gen_ai.latency.time_to_first_token"  # nosec B105
    GEN_AI_LATENCY_E2E = "gen_ai.latency.e2e"
    GEN_AI_LATENCY_TIME_IN_QUEUE = "gen_ai.latency.time_in_queue"
    GEN_AI_LATENCY_KV_CACHE_TRANSFER_TIME = "gen_ai.latency.kv_cache_transfer_time"

    # Response attributes
    GEN_AI_RESPONSE_FINISH_REASONS = "gen_ai.response.finish_reasons"
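
# These attribute names are plain strings (StrEnum), so they can be passed
# directly to the OpenTelemetry span API; for example (sketch):
#
#   span.set_attribute(SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS, 42)
#   span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TEMPERATURE, 0.8)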


class SpanEvents(StrEnum):
    """Span events for LLM tracing."""

    KV_CACHE_TRANSFER_START = "kv_cache_transfer_start"
    KV_CACHE_TRANSFER_END = "kv_cache_transfer_end"
    CTX_SERVER_SELECTED = "ctx_server.selected"
    GEN_SERVER_SELECTED = "gen_server.selected"


def contains_trace_headers(headers: Mapping[str, str]) -> bool:
    lower_keys = {k.lower() for k in headers.keys()}
    return any(h in lower_keys for h in TRACE_HEADERS)


def add_event(name: str,
              attributes: Optional[Mapping[str, object]] = None,
              timestamp: typing.Optional[int] = None) -> None:
    """Add an event to the current span if tracing is available."""
    if not is_tracing_enabled():
        return
    get_current_span().add_event(name, attributes, timestamp)
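
# For example (sketch; per the OpenTelemetry API, an explicit timestamp is
# given in nanoseconds since the epoch when provided):
#
#   add_event(SpanEvents.KV_CACHE_TRANSFER_START,
#             attributes={SpanAttributes.GEN_AI_REQUEST_ID: "req-0"})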


@run_once
def log_tracing_disabled_warning() -> None:
    logger.warning(
        "Received a request with trace context but tracing is disabled")


@run_once
def insufficient_request_metrics_warning() -> None:
    logger.warning(
        "Insufficient request metrics available; trace generation aborted.")


def trace_span(name: Optional[str] = None):
    """Decorate an async function so each call runs inside a SERVER span."""

    def decorator(func):

        @functools.wraps(func)
        async def async_wrapper(*args, **kwargs):
            span_name = name if name is not None else func.__name__
            if global_otlp_tracer() is None:
                return await func(*args, **kwargs)

            # Pick up an upstream trace context from the first argument that
            # carries HTTP-style headers, so this span joins the caller's
            # trace instead of starting a new one.
            trace_context = None
            for arg in list(args) + list(kwargs.values()):
                if hasattr(arg, 'headers'):
                    trace_context = extract_trace_context(arg.headers)
                    break

            with global_otlp_tracer().start_as_current_span(
                    span_name, kind=SpanKind.SERVER,
                    context=trace_context) as span:
                try:
                    result = await func(*args, **kwargs)
                    span.set_status(Status(StatusCode.OK))
                    return result
                except Exception as e:
                    span.record_exception(e)
                    span.set_status(
                        Status(StatusCode.ERROR, f"An error occurred: {e}"))
                    raise

        return async_wrapper

    return decorator
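
# Usage sketch (the handler and its `raw_request` argument are hypothetical;
# any argument exposing a `.headers` mapping links the span to the caller's
# trace):
#
#   @trace_span("openai_completion")
#   async def handle_completion(raw_request):
#       ...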