[None][feat] Add opentelemetry tracing (#5897)
Signed-off-by: Zhang Haotong <zhanghaotong.zht@antgroup.com>
Signed-off-by: zhanghaotong <zhanghaotong.zht@antgroup.com>
Signed-off-by: Shunkang <182541032+Shunkangz@users.noreply.github.co>
Co-authored-by: Zhang Haotong <zhanghaotong.zht@alibaba-inc.com>
Co-authored-by: Shunkang <182541032+Shunkangz@users.noreply.github.co>
parent: ce0d76135d    commit: 1026069a2b
examples/opentelemetry/README.md (new file, 85 lines)
@@ -0,0 +1,85 @@
# OpenTelemetry Integration Guide

This guide explains how to set up OpenTelemetry tracing in TensorRT-LLM to monitor and debug your LLM inference services.

## Install OpenTelemetry

Install the required OpenTelemetry packages:

```bash
pip install \
  'opentelemetry-sdk' \
  'opentelemetry-api' \
  'opentelemetry-exporter-otlp' \
  'opentelemetry-semantic-conventions-ai'
```
## Start Jaeger

You can start Jaeger with Docker:

```bash
docker run --rm --name jaeger \
  -e COLLECTOR_ZIPKIN_HOST_PORT=:9411 \
  -p 6831:6831/udp \
  -p 6832:6832/udp \
  -p 5778:5778 \
  -p 16686:16686 \
  -p 4317:4317 \
  -p 4318:4318 \
  -p 14250:14250 \
  -p 14268:14268 \
  -p 14269:14269 \
  -p 9411:9411 \
  jaegertracing/all-in-one:1.57.0
```

Or run the jaeger-all-in-one(.exe) executable from [the binary distribution archives](https://www.jaegertracing.io/download/):

```bash
jaeger-all-in-one --collector.zipkin.host-port=:9411
```
## Set up environment variables and run TensorRT-LLM

Set up the environment variables:

```bash
export JAEGER_IP=$(docker inspect --format '{{ .NetworkSettings.IPAddress }}' jaeger)
export OTEL_EXPORTER_OTLP_TRACES_PROTOCOL=grpc
export OTEL_EXPORTER_OTLP_TRACES_ENDPOINT=grpc://$JAEGER_IP:4317
export OTEL_EXPORTER_OTLP_TRACES_INSECURE=true
export OTEL_SERVICE_NAME="trt-server"
```

Then run TensorRT-LLM with OpenTelemetry enabled, making sure `return_perf_metrics` is set to true in the model configuration:

```bash
trtllm-serve models/Qwen3-8B/ --otlp_traces_endpoint="$OTEL_EXPORTER_OTLP_TRACES_ENDPOINT"
```
## Send requests and find traces in Jaeger

You can send a request to the server and view the resulting traces in the [Jaeger UI](http://localhost:16686/).
The traces should be visible under the service name "trt-server".
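If your client is itself instrumented, you can also propagate its trace context to the server through the standard W3C `traceparent` header, which the server extracts and attaches to its `llm_request` span. A minimal sketch using Python `requests` (the port, model name, and the hard-coded `traceparent` value are placeholders; a real client would let its own tracer generate the header):

```python
import requests

headers = {
    "Content-Type": "application/json",
    # W3C trace context: 00-<32 hex trace id>-<16 hex parent span id>-<flags>.
    "traceparent": "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01",
}
payload = {
    "model": "Qwen3-8B",          # placeholder: the model you launched above
    "prompt": "Hello, tracing!",
    "max_tokens": 16,
}

# Assumes trtllm-serve is listening on its default port 8000.
resp = requests.post("http://localhost:8000/v1/completions",
                     json=payload, headers=headers, timeout=60)
print(resp.json())
```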
## Configuration for Disaggregated Serving

For disaggregated serving, the context and generation servers are configured exactly as in the standalone case. For the proxy, you can configure tracing as follows:

```yaml
# disagg_config.yaml
hostname: 127.0.0.1
port: 8000
backend: pytorch
context_servers:
  num_instances: 1
  urls:
    - "127.0.0.1:8001"
generation_servers:
  num_instances: 1
  urls:
    - "127.0.0.1:8002"
otlp_config:
  otlp_traces_endpoint: "grpc://0.0.0.0:4317"
```
@@ -28,13 +28,14 @@ from contextlib import contextmanager
from enum import EnumMeta
from functools import lru_cache, partial, wraps
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Union
from typing import Any, Callable, Dict, List, Optional, Sequence, Union

import numpy as np
import nvtx
from mpi4py import MPI
from mpi4py.util import pkl5
from packaging import version
from typing_extensions import ParamSpec

# isort: off
import torch
@@ -1155,6 +1156,21 @@ def set_prometheus_multiproc_dir() -> object:
        f"PROMETHEUS_MULTIPROC_DIR: {os.environ['PROMETHEUS_MULTIPROC_DIR']}")


P = ParamSpec("P")


# From: https://stackoverflow.com/a/4104188/2749989
def run_once(f: Callable[P, None]) -> Callable[P, None]:

    def wrapper(*args: P.args, **kwargs: P.kwargs) -> None:
        if not wrapper.has_run:  # type: ignore[attr-defined]
            wrapper.has_run = True  # type: ignore[attr-defined]
            return f(*args, **kwargs)

    wrapper.has_run = False  # type: ignore[attr-defined]
    return wrapper


TORCH_PYBIND11_ABI = None
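For context on the helper added above: `run_once` executes the wrapped function on its first call and turns every later call into a no-op, which is what lets the tracing module's warning helpers be invoked on every request without flooding the log. A small usage sketch, assuming the helper is importable as `tensorrt_llm._utils.run_once` after this change:

```python
from tensorrt_llm._utils import run_once


@run_once
def warn_once() -> None:
    print("logged exactly once")


warn_once()  # executes and prints
warn_once()  # no-op: the wrapper's has_run flag is already set
```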
@@ -91,6 +91,7 @@ def get_llm_args(model: str,
                 trust_remote_code: bool = False,
                 reasoning_parser: Optional[str] = None,
                 fail_fast_on_attention_window_too_large: bool = False,
                 otlp_traces_endpoint: Optional[str] = None,
                 enable_chunked_prefill: bool = False,
                 **llm_args_extra_dict: Any):

@@ -134,6 +135,7 @@ def get_llm_args(model: str,
        "reasoning_parser": reasoning_parser,
        "fail_fast_on_attention_window_too_large":
        fail_fast_on_attention_window_too_large,
        "otlp_traces_endpoint": otlp_traces_endpoint,
        "enable_chunked_prefill": enable_chunked_prefill,
    }

@@ -322,6 +324,10 @@ class ChoiceWithAlias(click.Choice):
              help=
              "Exit with runtime error when attention window is too large to fit even a single sequence in the KV cache."
              )
@click.option("--otlp_traces_endpoint",
              type=str,
              default=None,
              help="Target URL to which OpenTelemetry traces will be sent.")
@click.option("--disagg_cluster_uri",
              type=str,
              default=None,

@@ -344,8 +350,8 @@ def serve(
        extra_llm_api_options: Optional[str], reasoning_parser: Optional[str],
        metadata_server_config_file: Optional[str], server_role: Optional[str],
        fail_fast_on_attention_window_too_large: bool,
        enable_chunked_prefill: bool, disagg_cluster_uri: Optional[str],
        media_io_kwargs: Optional[str]):
        otlp_traces_endpoint: Optional[str], enable_chunked_prefill: bool,
        disagg_cluster_uri: Optional[str], media_io_kwargs: Optional[str]):
    """Running an OpenAI API compatible server

    MODEL: model name | HF checkpoint path | TensorRT engine path

@@ -371,6 +377,7 @@ def serve(
        reasoning_parser=reasoning_parser,
        fail_fast_on_attention_window_too_large=
        fail_fast_on_attention_window_too_large,
        otlp_traces_endpoint=otlp_traces_endpoint,
        enable_chunked_prefill=enable_chunked_prefill)

    llm_args_extra_dict = {}
@@ -886,7 +886,15 @@ def _get_metrics_dict(
            req_perf_metrics.timing_metrics.first_scheduled_time.
            total_seconds(),
            RequestEventTiming.LAST_TOKEN_TIME:
            req_perf_metrics.timing_metrics.last_token_time.total_seconds()
            req_perf_metrics.timing_metrics.last_token_time.total_seconds(),
            RequestEventTiming.KV_CACHE_TRANSFER_START:
            req_perf_metrics.timing_metrics.kv_cache_transfer_start.
            total_seconds(),
            RequestEventTiming.KV_CACHE_TRANSFER_END:
            req_perf_metrics.timing_metrics.kv_cache_transfer_end.
            total_seconds(),
            RequestEventTiming.KV_CACHE_SIZE:
            req_perf_metrics.timing_metrics.kv_cache_size,
        }
    return metrics_dict
@@ -5,6 +5,7 @@ import platform
import signal
import traceback
from abc import ABC, abstractmethod
from collections.abc import Mapping
from pathlib import Path
from queue import Queue
from typing import (TYPE_CHECKING, AsyncIterable, Dict, Generator, List,

@@ -123,6 +124,7 @@ class GenerationExecutor(ABC):
            streaming: bool = False,
            kv_cache_retention_config: Optional[KvCacheRetentionConfig] = None,
            disaggregated_params: Optional[DisaggregatedParams] = None,
            trace_headers: Optional[Mapping[str, str]] = None,
            postproc_params: Optional[PostprocParams] = None,
            multimodal_params: Optional[MultimodalParams] = None,
            scheduling_params: Optional[SchedulingParams] = None,

@@ -150,6 +152,7 @@ class GenerationExecutor(ABC):
            streaming=streaming,
            kv_cache_retention_config=kv_cache_retention_config,
            disaggregated_params=disaggregated_params,
            trace_headers=trace_headers,
            multimodal_params=multimodal_params,
            scheduling_params=scheduling_params,
            cache_salt_id=cache_salt_id,
@@ -1,4 +1,5 @@
import os
from collections.abc import Mapping
from dataclasses import dataclass
from typing import List, Optional, Union

@@ -94,6 +95,7 @@ class GenerationRequest:
            streaming: bool = False,
            kv_cache_retention_config: Optional[KvCacheRetentionConfig] = None,
            disaggregated_params: Optional[DisaggregatedParams] = None,
            trace_headers: Optional[Mapping[str, str]] = None,
            postproc_params: Optional[PostprocParams] = None,
            multimodal_params: Optional[MultimodalParams] = None,
            scheduling_params: Optional[SchedulingParams] = None,

@@ -123,6 +125,7 @@ class GenerationRequest:
        self.kv_cache_retention_config = kv_cache_retention_config
        self.id: Optional[int] = None
        self.disaggregated_params = disaggregated_params
        self.trace_headers = trace_headers
        self.scheduling_params = scheduling_params
        self.cache_salt_id = cache_salt_id
        self.arrival_time = arrival_time
@@ -1,6 +1,7 @@
import asyncio
import json
import threading
import time
import weakref
from dataclasses import dataclass, field
from queue import Empty, Queue

@@ -11,6 +12,8 @@ from weakref import WeakMethod
import torch
import torch.nn.functional as F

from tensorrt_llm.llmapi import tracing

try:
    import ray
except ModuleNotFoundError:

@@ -268,6 +271,7 @@ class GenerationResultBase:
        self.avg_decoded_tokens_per_iter: Optional[float] = None
        self._done = False
        self.metrics_dict = {}
        self.trace_headers: Optional[dict[str, str]] = None

        if ray_queue is not None:
            if has_event_loop():

@@ -436,6 +440,7 @@ class GenerationResultBase:
                raise ValueError(
                    f"Unknown finish reason: {finish_reasons[src_idx]}")
        self.record_stats(output, req_perf_metrics_dict)
        self.do_tracing(output, req_perf_metrics_dict)

    @print_traceback_on_error
    @nvtx_range_debug("handle_response",

@@ -472,7 +477,7 @@ class GenerationResultBase:
            self._outputs[0].disaggregated_params = disaggregated_params

        if response.metrics:
            self.metrics_dict = response.metrics
            self.metrics_dict.update(response.metrics)

        if response.error:
            if self._background_error_handler is not None and (
@@ -570,7 +575,110 @@ class GenerationResultBase:
                stats, len(output.token_ids), self.sampling_params.n > 1)
            if processed_metrics_stat:
                metrics_stats.update(processed_metrics_stat)
        self.metrics_dict = metrics_stats
        self.metrics_dict.update(metrics_stats)

    def do_tracing(
        self,
        output: CompletionOutput,
        req_perf_metrics_dict: Optional[dict[str, float]] = None,
    ) -> None:
        """Perform distributed tracing for the generation request.

        Args:
            output (CompletionOutput): The output of the generation result.
            req_perf_metrics_dict (Optional[dict[str, float]]): Request performance metrics. Defaults to None.
        """
        if not tracing.global_otlp_tracer():
            return

        metrics_dict = self.metrics_dict
        if not metrics_dict or not req_perf_metrics_dict:
            # Insufficient request metrics available; trace generation aborted.
            tracing.insufficient_request_metrics_warning()
            return

        trace_context = tracing.extract_trace_context(self.trace_headers)
        sampling_params = self.sampling_params

        # Since arrival_time and other timing metrics are based on different time origins,
        # we need to apply corrections to align them with absolute timestamps
        time_correction = time.time() - time.monotonic()
        arrival_time = req_perf_metrics_dict.get(
            RequestEventTiming.ARRIVAL_TIME, 0)

        with tracing.global_otlp_tracer().start_as_current_span(
                "llm_request",
                kind=tracing.SpanKind.SERVER,
                context=trace_context,
                start_time=int((arrival_time + time_correction) * 1e9),
        ) as span:

            def safe_set_attr(span, attr, value):
                if value is not None:
                    span.set_attribute(attr, value)

            safe_set_attr(span,
                          tracing.SpanAttributes.GEN_AI_REQUEST_TEMPERATURE,
                          sampling_params.temperature)
            safe_set_attr(span, tracing.SpanAttributes.GEN_AI_REQUEST_TOP_P,
                          sampling_params.top_p)
            safe_set_attr(span, tracing.SpanAttributes.GEN_AI_REQUEST_TOP_K,
                          sampling_params.top_k)
            safe_set_attr(
                span,
                tracing.SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS,
                sampling_params.max_tokens,
            )
            safe_set_attr(span, tracing.SpanAttributes.GEN_AI_REQUEST_N,
                          sampling_params.n)
            safe_set_attr(span, tracing.SpanAttributes.GEN_AI_REQUEST_ID,
                          self.id)
            if prompt_token_ids := getattr(self, "prompt_token_ids", None):
                safe_set_attr(span,
                              tracing.SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS,
                              len(prompt_token_ids))
            safe_set_attr(span,
                          tracing.SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS,
                          output.length)
            safe_set_attr(
                span, tracing.SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN,
                metrics_dict.get(MetricNames.TTFT, -1))
            safe_set_attr(span, tracing.SpanAttributes.GEN_AI_LATENCY_E2E,
                          metrics_dict.get(MetricNames.E2E, -1))
            safe_set_attr(span,
                          tracing.SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE,
                          metrics_dict.get(MetricNames.REQUEST_QUEUE_TIME, -1))
            safe_set_attr(
                span, tracing.SpanAttributes.GEN_AI_RESPONSE_FINISH_REASONS,
                json.dumps([output.finish_reason])
                if output.finish_reason else None)
            safe_set_attr(
                span,
                tracing.SpanAttributes.GEN_AI_LATENCY_KV_CACHE_TRANSFER_TIME,
                req_perf_metrics_dict.get(
                    RequestEventTiming.KV_CACHE_TRANSFER_END, 0.0) -
                req_perf_metrics_dict.get(
                    RequestEventTiming.KV_CACHE_TRANSFER_START, 0.0))

            if req_perf_metrics_dict.get(
                    RequestEventTiming.KV_CACHE_TRANSFER_START,
                    0) and req_perf_metrics_dict.get(
                        RequestEventTiming.KV_CACHE_TRANSFER_END, 0):
                tracing.add_event(
                    tracing.SpanEvents.KV_CACHE_TRANSFER_START,
                    timestamp=int((req_perf_metrics_dict.get(
                        RequestEventTiming.KV_CACHE_TRANSFER_START, 0.0) +
                                   time_correction) * 1e9))
                tracing.add_event(
                    tracing.SpanEvents.KV_CACHE_TRANSFER_END,
                    attributes={
                        "kv_cache_size":
                        req_perf_metrics_dict.get(
                            RequestEventTiming.KV_CACHE_SIZE, 0)
                    },
                    timestamp=int((req_perf_metrics_dict.get(
                        RequestEventTiming.KV_CACHE_TRANSFER_END, 0.0) +
                                   time_correction) * 1e9))


class DetokenizedGenerationResultBase(GenerationResultBase):
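The time-correction step in `do_tracing` above is worth spelling out: the request timings arrive on a monotonic clock, while OpenTelemetry expects span and event timestamps as wall-clock nanoseconds, so a single offset converts between the two. A standalone sketch of the same idea (variable names here are illustrative only):

```python
import time

# Offset between the wall clock and the monotonic clock, captured once.
time_correction = time.time() - time.monotonic()

# A hypothetical event timestamp recorded on the monotonic clock.
monotonic_event_time = time.monotonic()

# Convert to absolute wall-clock nanoseconds, the unit used for
# start_time= and timestamp= in the tracing calls above.
event_time_ns = int((monotonic_event_time + time_correction) * 1e9)
print(event_time_ns)
```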
@@ -688,6 +796,7 @@ class GenerationResult(GenerationResultBase):
        self.disaggregated_params = disaggregated_params
        # minimal sampling params needed for logprob calculation
        self._logprob_params = logprob_params
        self.trace_headers = generation_request.trace_headers

        # for aborting the request
        self._executor: Optional[weakref.ReferenceType[
@@ -43,6 +43,12 @@ class ConditionalDisaggConfig():
    max_local_prefill_length: int = 0


@dataclass
class OtlpConfig():
    otlp_traces_endpoint: Optional[
        str] = None  # Target URL to which OpenTelemetry traces will be sent


@dataclass
class MinimalInstances:
    context_servers: int = 1  # the minimal number of context servers

@@ -66,6 +72,7 @@ class DisaggServerConfig():
    ctx_router_config: Optional[RouterConfig] = None
    gen_router_config: Optional[RouterConfig] = None
    conditional_disagg_config: Optional[ConditionalDisaggConfig] = None
    otlp_config: Optional[OtlpConfig] = None
    max_retries: int = 1
    perf_metrics_max_requests: int = 0
    disagg_cluster_config: Optional[DisaggClusterConfig] = None

@@ -112,6 +119,7 @@ def extract_disagg_cfg(hostname: str = 'localhost',
                       context_servers: Optional[dict] = None,
                       generation_servers: Optional[dict] = None,
                       conditional_disagg_config: Optional[dict] = None,
                       otlp_config: Optional[dict] = None,
                       disagg_cluster: Optional[dict] = None,
                       **kwargs: Any) -> DisaggServerConfig:
    context_servers = context_servers or {}

@@ -149,10 +157,12 @@ def extract_disagg_cfg(hostname: str = 'localhost',
    conditional_disagg_config = ConditionalDisaggConfig(
        **conditional_disagg_config) if conditional_disagg_config else None

    otlp_config = OtlpConfig(**otlp_config) if otlp_config else None

    config = DisaggServerConfig(server_configs, hostname, port,
                                ctx_router_config, gen_router_config,
                                conditional_disagg_config, max_retries,
                                perf_metrics_max_requests,
                                conditional_disagg_config, otlp_config,
                                max_retries, perf_metrics_max_requests,
                                disagg_cluster_config)

    return config
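To illustrate how the `otlp_config` block of `disagg_config.yaml` flows into the dataclass added above, here is a small sketch of what `extract_disagg_cfg` effectively does with that section (the dict literal stands in for the parsed YAML):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class OtlpConfig:
    # Mirrors the dataclass above: target URL for OpenTelemetry traces.
    otlp_traces_endpoint: Optional[str] = None


# Stand-in for the parsed `otlp_config:` section of disagg_config.yaml.
otlp_section = {"otlp_traces_endpoint": "grpc://0.0.0.0:4317"}

otlp_config = OtlpConfig(**otlp_section) if otlp_section else None
print(otlp_config)  # OtlpConfig(otlp_traces_endpoint='grpc://0.0.0.0:4317')
```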
@@ -6,6 +6,7 @@ import socket
import tempfile
import time
import weakref
from collections.abc import Mapping
from pathlib import Path
from typing import Any, List, Literal, Optional, Sequence, Union

@@ -17,6 +18,8 @@ from tensorrt_llm._utils import mpi_disabled
from tensorrt_llm.inputs.data import TextPrompt
from tensorrt_llm.inputs.multimodal import MultimodalInput, MultimodalParams
from tensorrt_llm.inputs.registry import DefaultInputProcessor
from tensorrt_llm.llmapi import tracing
from tensorrt_llm.metrics.enums import MetricNames

from .._utils import nvtx_range_debug
from ..bindings import executor as tllm

@@ -230,6 +233,15 @@ class BaseLLM:
                self.mpi_session.shutdown()
            raise

        try:
            if self.args.otlp_traces_endpoint:
                tracing.init_tracer("trt.llm", self.args.otlp_traces_endpoint)
                logger.info(
                    f"Initialized OTLP tracer successfully, endpoint: {self.args.otlp_traces_endpoint}"
                )
        except Exception as e:
            logger.error(f"Failed to initialize OTLP tracer: {e}")

        exception_handler.register(self, 'shutdown')
        atexit.register(LLM._shutdown_wrapper, weakref.ref(self))
@@ -338,6 +350,7 @@ class BaseLLM:
        streaming: bool = False,
        kv_cache_retention_config: Optional[KvCacheRetentionConfig] = None,
        disaggregated_params: Optional[DisaggregatedParams] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        _postproc_params: Optional[PostprocParams] = None,
        scheduling_params: Optional[SchedulingParams] = None,
        cache_salt: Optional[str] = None,

@@ -354,6 +367,7 @@ class BaseLLM:
            streaming (bool): Whether to use the streaming mode for the generation. Defaults to False.
            kv_cache_retention_config (tensorrt_llm.bindings.executor.KvCacheRetentionConfig, optional): Configuration for the request's retention in the KV Cache. Defaults to None.
            disaggregated_params (tensorrt_llm.disaggregated_params.DisaggregatedParams, optional): Disaggregated parameters. Defaults to None.
            trace_headers (Mapping[str, str], optional): Trace headers. Defaults to None.
            scheduling_params (tensorrt_llm.scheduling_params.SchedulingParams, optional): Scheduling parameters. Defaults to None.
            cache_salt (str, optional): If specified, KV cache will be salted with the provided string to limit the kv cache reuse to the requests with the same string. Defaults to None.
        Returns:

@@ -486,6 +500,7 @@ class BaseLLM:
            streaming=streaming,
            kv_cache_retention_config=kv_cache_retention_config,
            disaggregated_params=disaggregated_params,
            trace_headers=trace_headers,
            postproc_params=_postproc_params,
            multimodal_params=multimodal_params,
            scheduling_params=scheduling_params,

@@ -493,6 +508,10 @@ class BaseLLM:
            arrival_time=arrival_time,
        )

        if sampling_params.return_perf_metrics:
            result.metrics_dict.update(
                {MetricNames.ARRIVAL_TIMESTAMP: time.time()})

        return RequestOutput._from_generation_result(result, prompt,
                                                     self.tokenizer)
@@ -1703,6 +1703,12 @@ class BaseLlmArgs(StrictBaseModel):
        exclude=True,
        alias="_mpi_session")

    otlp_traces_endpoint: Optional[str] = Field(
        default=None,
        description="Target URL to which OpenTelemetry traces will be sent.",
        alias="otlp_traces_endpoint",
        status="prototype")

    backend: Optional[str] = Field(
        default=None,
        description="The backend to use for this LLM instance.",
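Because `otlp_traces_endpoint` is now a `BaseLlmArgs` field and `BaseLLM.__init__` calls `tracing.init_tracer` when it is set, traces can also be emitted from the LLM API directly, without `trtllm-serve`. A hedged sketch (model path and collector endpoint are placeholders, and it assumes an OTLP collector such as the Jaeger container from the README is running):

```python
from tensorrt_llm import LLM, SamplingParams

# otlp_traces_endpoint maps onto the BaseLlmArgs field added above.
llm = LLM(model="models/Qwen3-8B/",              # placeholder model path
          otlp_traces_endpoint="grpc://localhost:4317")

# return_perf_metrics exposes the timing metrics that do_tracing reads.
params = SamplingParams(max_tokens=16, return_perf_metrics=True)

for out in llm.generate(["Hello, tracing!"], params):
    print(out.outputs[0].text)
```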
tensorrt_llm/llmapi/tracing.py (new file, 227 lines)
@@ -0,0 +1,227 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.

__all__ = [
    'SpanAttributes', 'SpanKind', 'contains_trace_headers',
    'extract_trace_context', 'get_span_exporter', 'global_otlp_tracer',
    'init_tracer', 'insufficient_request_metrics_warning', 'is_otel_available',
    'is_tracing_enabled', 'log_tracing_disabled_warning',
    'set_global_otlp_tracer', 'extract_trace_headers'
]

import functools
import os
import typing
from collections.abc import Mapping
from typing import Optional

from strenum import StrEnum

from tensorrt_llm._utils import run_once
from tensorrt_llm.logger import logger

# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0rc1/vllm/tracing.py#L11
TRACE_HEADERS = ["traceparent", "tracestate"]

_global_tracer_ = None
_is_otel_imported = False
otel_import_error_traceback: Optional[str] = None

try:
    from opentelemetry.context.context import Context
    from opentelemetry.sdk.environment_variables import \
        OTEL_EXPORTER_OTLP_TRACES_PROTOCOL
    from opentelemetry.sdk.trace import TracerProvider
    from opentelemetry.sdk.trace.export import BatchSpanProcessor
    from opentelemetry.trace import (SpanKind, Status, StatusCode, Tracer,
                                     get_current_span, set_tracer_provider)
    from opentelemetry.trace.propagation.tracecontext import \
        TraceContextTextMapPropagator

    _is_otel_imported = True
except ImportError:
    import traceback

    otel_import_error_traceback = traceback.format_exc()

    class Context:  # type: ignore
        pass

    class BaseSpanAttributes:  # type: ignore
        pass

    class SpanKind:  # type: ignore
        pass

    class Tracer:  # type: ignore
        pass


def is_otel_available() -> bool:
    return _is_otel_imported


def init_tracer(instrumenting_module_name: str,
                otlp_traces_endpoint: str) -> Optional[Tracer]:
    if not is_otel_available():
        raise ValueError(
            "OpenTelemetry is not available. Unable to initialize "
            "a tracer. Ensure OpenTelemetry packages are installed. "
            f"Original error:\n{otel_import_error_traceback}")
    trace_provider = TracerProvider()
    span_exporter = get_span_exporter(otlp_traces_endpoint)
    trace_provider.add_span_processor(BatchSpanProcessor(span_exporter))
    set_tracer_provider(trace_provider)
    tracer = trace_provider.get_tracer(instrumenting_module_name)
    set_global_otlp_tracer(tracer)
    return tracer


def get_span_exporter(endpoint):
    protocol = os.environ.get(OTEL_EXPORTER_OTLP_TRACES_PROTOCOL, "grpc")
    if protocol == "grpc":
        from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import \
            OTLPSpanExporter
    elif protocol == "http/protobuf":
        from opentelemetry.exporter.otlp.proto.http.trace_exporter import \
            OTLPSpanExporter  # type: ignore
    else:
        raise ValueError(
            f"Unsupported OTLP protocol '{protocol}' is configured")
    return OTLPSpanExporter(endpoint=endpoint)
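`get_span_exporter` chooses the exporter class from the standard `OTEL_EXPORTER_OTLP_TRACES_PROTOCOL` environment variable, so switching from gRPC to HTTP export is purely a configuration change. A small sketch (endpoints are placeholders for a local collector, and the OpenTelemetry packages from the README must be installed):

```python
import os

from tensorrt_llm.llmapi.tracing import get_span_exporter

# "grpc" is the default and typically targets port 4317.
os.environ["OTEL_EXPORTER_OTLP_TRACES_PROTOCOL"] = "grpc"
grpc_exporter = get_span_exporter("grpc://localhost:4317")

# "http/protobuf" selects the HTTP exporter instead, typically port 4318.
os.environ["OTEL_EXPORTER_OTLP_TRACES_PROTOCOL"] = "http/protobuf"
http_exporter = get_span_exporter("http://localhost:4318/v1/traces")

# Any other value raises ValueError("Unsupported OTLP protocol ... is configured").
```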
def extract_trace_context(
        headers: Optional[Mapping[str, str]]) -> Optional[Context]:
    if is_otel_available():
        headers = headers or {}
        return TraceContextTextMapPropagator().extract(headers)
    else:
        return None


def extract_trace_headers(
        headers: Mapping[str, str]) -> Optional[Mapping[str, str]]:
    if is_tracing_enabled():
        # Return only recognized trace headers with normalized lowercase keys
        lower_map = {k.lower(): v for k, v in headers.items()}
        return {h: lower_map[h] for h in TRACE_HEADERS if h in lower_map}
    if contains_trace_headers(headers):
        log_tracing_disabled_warning()
    return None


def inject_trace_headers(headers: Mapping[str, str]) -> Mapping[str, str]:
    if is_tracing_enabled():
        trace_headers = extract_trace_headers(headers) if not headers else {}
        TraceContextTextMapPropagator().inject(trace_headers)
        return trace_headers
    return None
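The two header helpers work in opposite directions: `extract_trace_headers` keeps only the recognized W3C headers from an incoming request (with lower-cased keys), while `inject_trace_headers` stamps the currently active span's context onto outgoing headers. A sketch of the extraction side; the header value is a placeholder, and the shown result assumes a tracer has already been initialized so `is_tracing_enabled()` is true:

```python
from tensorrt_llm.llmapi.tracing import extract_trace_headers

incoming = {
    "Traceparent": "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01",
    "User-Agent": "example-client",  # unrelated header, dropped
}

# With tracing enabled: only the recognized trace headers survive, keyed by
# their normalized lowercase names. With tracing disabled: returns None and
# logs a one-time "tracing is disabled" warning instead.
print(extract_trace_headers(incoming))
# {'traceparent': '00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01'}
```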
def global_otlp_tracer() -> Tracer:
    """Get the global OTLP instance in the current process."""
    return _global_tracer_


def set_global_otlp_tracer(tracer: Tracer):
    """Set the global OTLP Tracer instance in the current process."""
    global _global_tracer_
    assert _global_tracer_ is None
    _global_tracer_ = tracer


def is_tracing_enabled() -> bool:
    return _global_tracer_ is not None


class SpanAttributes(StrEnum):
    """Span attributes for LLM tracing following GenAI semantic conventions."""

    # Token usage attributes
    GEN_AI_USAGE_COMPLETION_TOKENS = "gen_ai.usage.completion_tokens"
    GEN_AI_USAGE_PROMPT_TOKENS = "gen_ai.usage.prompt_tokens"

    # Request attributes
    GEN_AI_REQUEST_MAX_TOKENS = "gen_ai.request.max_tokens"
    GEN_AI_REQUEST_TOP_P = "gen_ai.request.top_p"
    GEN_AI_REQUEST_TOP_K = "gen_ai.request.top_k"
    GEN_AI_REQUEST_TEMPERATURE = "gen_ai.request.temperature"
    GEN_AI_REQUEST_ID = "gen_ai.request.id"
    GEN_AI_REQUEST_N = "gen_ai.request.n"

    # Latency attributes
    GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN = "gen_ai.latency.time_to_first_token"  # nosec B105
    GEN_AI_LATENCY_E2E = "gen_ai.latency.e2e"
    GEN_AI_LATENCY_TIME_IN_QUEUE = "gen_ai.latency.time_in_queue"
    GEN_AI_LATENCY_KV_CACHE_TRANSFER_TIME = "gen_ai.latency.kv_cache_transfer_time"

    # Response attributes
    GEN_AI_RESPONSE_FINISH_REASONS = "gen_ai.response.finish_reasons"


class SpanEvents(StrEnum):
    """Span events for LLM tracing."""
    KV_CACHE_TRANSFER_START = "kv_cache_transfer_start"
    KV_CACHE_TRANSFER_END = "kv_cache_transfer_end"
    CTX_SERVER_SELECTED = "ctx_server.selected"
    GEN_SERVER_SELECTED = "gen_server.selected"


def contains_trace_headers(headers: Mapping[str, str]) -> bool:
    lower_keys = {k.lower() for k in headers.keys()}
    return any(h in lower_keys for h in TRACE_HEADERS)


def add_event(name: str,
              attributes: Optional[Mapping[str, object]] = None,
              timestamp: typing.Optional[int] = None) -> None:
    """Add an event to the current span if tracing is available."""
    if not is_tracing_enabled():
        return
    get_current_span().add_event(name, attributes, timestamp)

@run_once
def log_tracing_disabled_warning() -> None:
    logger.warning(
        "Received a request with trace context but tracing is disabled")


@run_once
def insufficient_request_metrics_warning() -> None:
    logger.warning(
        "Insufficient request metrics available; trace generation aborted.")


def trace_span(name: str = None):

    def decorator(func):

        @functools.wraps(func)
        async def async_wrapper(*args, **kwargs):
            span_name = name if name is not None else func.__name__
            if global_otlp_tracer() is None:
                return await func(*args, **kwargs)

            trace_headers = None
            for arg in list(args) + list(kwargs.values()):
                if hasattr(arg, 'headers'):
                    trace_headers = extract_trace_context(arg.headers)
                    break

            with global_otlp_tracer().start_as_current_span(
                    span_name, kind=SpanKind.SERVER,
                    context=trace_headers) as span:
                try:
                    result = await func(*args, **kwargs)
                    span.set_status(Status(StatusCode.OK))
                    return result
                except Exception as e:
                    span.record_exception(e)
                    span.set_status(
                        Status(StatusCode.ERROR, f"An error occurred: {e}"))
                    raise e

        return async_wrapper

    return decorator
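`trace_span` is what the disaggregated proxy below uses to wrap its `openai_completion` and `openai_chat_completion` handlers. A minimal sketch of decorating an arbitrary async handler; `FakeRequest` is a hypothetical stand-in for a FastAPI `Request`, and anything exposing a `.headers` mapping will have its trace context extracted and used as the span parent:

```python
import asyncio
from dataclasses import dataclass, field
from typing import Mapping

from tensorrt_llm.llmapi.tracing import init_tracer, trace_span


@dataclass
class FakeRequest:
    # Hypothetical request object: only .headers matters to trace_span.
    headers: Mapping[str, str] = field(default_factory=dict)


@trace_span("llm_request")
async def handle(request: FakeRequest) -> str:
    # Runs inside an "llm_request" SERVER span once a tracer is initialized;
    # without one, the decorator simply awaits the function unchanged.
    return "ok"


init_tracer("trt.llm", "grpc://localhost:4317")  # assumes a local OTLP collector
print(asyncio.run(handle(FakeRequest(headers={}))))
```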
@@ -6,6 +6,7 @@ class MetricNames(Enum):
    TPOT = "tpot"
    E2E = "e2e"
    REQUEST_QUEUE_TIME = "request_queue_time"
    ARRIVAL_TIMESTAMP = 'arrival_timestamp'


class RequestEventTiming(Enum):

@@ -13,3 +14,6 @@ class RequestEventTiming(Enum):
    FIRST_TOKEN_TIME = "first_token_time"  # nosec: B105
    FIRST_SCHEDULED_TIME = "first_scheduled_time"
    LAST_TOKEN_TIME = "last_token_time"  # nosec: B105
    KV_CACHE_TRANSFER_START = "kv_cache_transfer_start"
    KV_CACHE_TRANSFER_END = "kv_cache_transfer_end"
    KV_CACHE_SIZE = "kv_cache_size"
@@ -6,6 +6,7 @@ import os
import signal
import traceback
from collections import deque
from collections.abc import Mapping
from contextlib import asynccontextmanager
from http import HTTPStatus
from typing import Callable, Optional, Type, Union

@@ -19,6 +20,7 @@ from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR

# yapf: disable
from tensorrt_llm.executor import CppExecutorError
from tensorrt_llm.llmapi import tracing
from tensorrt_llm.llmapi.disagg_utils import (DisaggServerConfig,
                                              MetadataServerConfig, ServerRole,
                                              get_ctx_gen_server_urls)

@@ -57,6 +59,17 @@ class OpenAIDisaggServer:
        self.gen_router = create_router(
            config.gen_router_config, self.gen_servers, metadata_server_cfg, self.metadata_server)
        self.conditional_disagg_config = config.conditional_disagg_config
        self.otlp_cfg = config.otlp_config

        try:
            if self.otlp_cfg and self.otlp_cfg.otlp_traces_endpoint:
                tracing.init_tracer("trt.llm", self.otlp_cfg.otlp_traces_endpoint)
                logger.info(
                    f"Initialized OTLP tracer successfully, endpoint: {self.otlp_cfg.otlp_traces_endpoint}"
                )
        except Exception as e:
            logger.error(f"Failed to initialize OTLP tracer: {e}")

        self.perf_metrics_max_requests = config.perf_metrics_max_requests
        if self.perf_metrics_max_requests > 0:
            # record corresponding keys of context and generation servers for perf metrics
@@ -284,7 +297,7 @@ class OpenAIDisaggServer:

        return JSONResponse(content=return_metrics)

    @tracing.trace_span("llm_request")
    async def openai_completion(self, req: CompletionRequest, raw_request: Request) -> Response:
        if not await self.is_ready():
            raise HTTPException(status_code=400, detail="Cluster is not ready")

@@ -301,6 +314,7 @@ class OpenAIDisaggServer:
        except Exception as e:
            await self._handle_exception(e)

    @tracing.trace_span("llm_request")
    async def openai_chat_completion(self, req: ChatCompletionRequest, raw_request: Request) -> Response:
        if not await self.is_ready():
            raise HTTPException(status_code=400, detail="Cluster is not ready")

@@ -319,7 +333,8 @@ class OpenAIDisaggServer:
            logger.error(traceback.format_exc())
            raise HTTPException(status_code=500, detail=f"Internal server error {str(exception)}")

    async def _send_context_request(self, ctx_server: str, ctx_req: Union[CompletionRequest, ChatCompletionRequest]):
    async def _send_context_request(self, ctx_server: str, ctx_req: Union[CompletionRequest, ChatCompletionRequest],
                                    trace_headers: Optional[Mapping[str, str]] = None):

        ctx_req.disaggregated_params = DisaggregatedParams(request_type="context_only")
        ctx_req.stream = False

@@ -329,10 +344,10 @@ class OpenAIDisaggServer:
        await self._increment_metric("ctx_total_requests")
        try:
            if isinstance(ctx_req, ChatCompletionRequest):
                ctx_response = await self.send_chat_request(ctx_server, ctx_req)
                ctx_response = await self.send_chat_request(ctx_server, ctx_req, trace_headers)
            else:
                assert isinstance(ctx_req, CompletionRequest)
                ctx_response = await self.send_completion_request(ctx_server, ctx_req)
                ctx_response = await self.send_completion_request(ctx_server, ctx_req, trace_headers)
        finally:
            await self.ctx_router.finish_request(ctx_req)
        await self._increment_metric("ctx_completed_requests")

@@ -352,9 +367,11 @@ class OpenAIDisaggServer:
        gen_server = None
        ctx_request_id = None
        need_ctx = False
        trace_headers = tracing.inject_trace_headers(raw_request.headers)

        async def _merge_streaming_responses(ctx_response,
                                             gen_req: Union[CompletionRequest, ChatCompletionRequest]):
                                             gen_req: Union[CompletionRequest, ChatCompletionRequest],
                                             trace_headers: Optional[Mapping[str, str]] = None):
            try:
                if ctx_response is not None and len(ctx_response.choices) != 1:
                    raise ValueError("Context server did not return a single choice. This is not expected")

@@ -366,9 +383,9 @@ class OpenAIDisaggServer:
                # Then yield the generation responses
                await self._increment_metric("gen_total_requests")
                if isinstance(gen_req, CompletionRequest):
                    gen_response = await self.send_completion_request(gen_server, gen_req)
                    gen_response = await self.send_completion_request(gen_server, gen_req, trace_headers)
                elif isinstance(gen_req, ChatCompletionRequest):
                    gen_response = await self.send_chat_request(gen_server, gen_req)
                    gen_response = await self.send_chat_request(gen_server, gen_req, trace_headers)
                else:
                    raise TypeError("Invalid request type: {type(gen_req).__name__}")
@@ -413,8 +430,11 @@ class OpenAIDisaggServer:
            if need_ctx:
                ctx_req = copy.deepcopy(req)
                ctx_server, _ = await self.ctx_router.get_next_server(ctx_req)

                tracing.add_event(tracing.SpanEvents.CTX_SERVER_SELECTED, attributes={"server": str(ctx_server),})

                # TODO: add ctx_server info into generation request for pre-registration
                ctx_response = await self._send_context_request(ctx_server, ctx_req)
                ctx_response = await self._send_context_request(ctx_server, ctx_req, trace_headers)

                if ctx_response is not None and len(ctx_response.choices) != 1:
                    raise ValueError("Context server did not return a single choice. This is not expected")

@@ -438,6 +458,7 @@ class OpenAIDisaggServer:
            if gen_server is None:
                gen_server, _ = await self.gen_router.get_next_server(req)
            logger.debug("Sending request to gen server: %s", gen_server)
            tracing.add_event(tracing.SpanEvents.GEN_SERVER_SELECTED, attributes={"server": str(gen_server),})

            if not req.stream:
                try:

@@ -448,10 +469,10 @@ class OpenAIDisaggServer:
                else:
                    await self._increment_metric("gen_total_requests")
                    if isinstance(req, CompletionRequest):
                        gen_response = await self.send_completion_request(gen_server, req)
                        gen_response = await self.send_completion_request(gen_server, req, trace_headers)
                    else:
                        assert isinstance(req, ChatCompletionRequest)
                        gen_response = await self.send_chat_request(gen_server, req)
                        gen_response = await self.send_chat_request(gen_server, req, trace_headers)
                    await self._increment_metric("gen_completed_requests")
                    if need_ctx and self.perf_metrics_keys is not None:
                        raw_request.state.server_first_token_time = get_steady_clock_now_in_seconds()

@@ -465,7 +486,7 @@ class OpenAIDisaggServer:
            else:
                # Return a streaming response that combines both context and generation responses
                return StreamingResponse(
                    _merge_streaming_responses(ctx_response, req),
                    _merge_streaming_responses(ctx_response, req, trace_headers),
                    media_type="text/event-stream"
                )
        except:
@@ -482,8 +503,15 @@ class OpenAIDisaggServer:
                                timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
        await uvicorn.Server(config).serve()

    async def create_generator(self, url: str, request: Union[CompletionRequest, ChatCompletionRequest], end_point: str):
        async with self.session.post(url + end_point, json=request.model_dump(exclude_unset=True)) as response:
    async def create_generator(self, url: str, request: Union[CompletionRequest, ChatCompletionRequest],
                               end_point: str, trace_headers: Optional[Mapping[str, str]] = None):
        # Prepare headers
        headers = {"Content-Type": "application/json"}
        if trace_headers:
            headers.update(trace_headers)

        async with self.session.post(url + end_point, json=request.model_dump(exclude_unset=True),
                                     headers=headers) as response:
            content_type = response.headers.get("Content-Type", "")
            if "text/event-stream" in content_type:
                if not request.stream:

@@ -498,26 +526,33 @@ class OpenAIDisaggServer:
                logger.error(f"Unexpected error in stream: {e}")
                raise

    async def create_completion_generator(self, url: str, request: CompletionRequest):
        async for chunk in self.create_generator(url, request, "/v1/completions"):
    async def create_completion_generator(self, url: str, request: CompletionRequest,
                                          trace_headers: Optional[Mapping[str, str]] = None):
        async for chunk in self.create_generator(url, request, "/v1/completions", trace_headers):
            yield chunk

    async def create_chat_generator(self, url: str, request: ChatCompletionRequest):
        async for chunk in self.create_generator(url, request, "/v1/chat/completions"):
    async def create_chat_generator(self, url: str, request: ChatCompletionRequest,
                                    trace_headers: Optional[Mapping[str, str]] = None):
        async for chunk in self.create_generator(url, request, "/v1/chat/completions", trace_headers):
            yield chunk

    async def send_request(self, url: str,
                           request: Union[CompletionRequest, ChatCompletionRequest],
                           endpoint: str,
                           response_type: Type[Union[CompletionResponse, ChatCompletionResponse]],
                           create_generator: Callable) -> Union[CompletionResponse, ChatCompletionResponse, StreamingResponse]:
                           create_generator: Callable,
                           trace_headers: Optional[Mapping[str, str]] = None) -> Union[CompletionResponse, ChatCompletionResponse, StreamingResponse]:
        for attempt in range(self.max_retries + 1):
            try:
                headers = {"Content-Type": "application/json"}
                if trace_headers:
                    headers.update(trace_headers)
                if request.stream:
                    response_generator = create_generator(url, request)
                    response_generator = create_generator(url, request, headers)
                    return StreamingResponse(content=response_generator, media_type="text/event-stream")
                else:
                    async with self.session.post(url + endpoint, json=request.model_dump(exclude_unset=True)) as response:
                    async with self.session.post(url + endpoint, json=request.model_dump(exclude_unset=True),
                                                 headers=headers) as response:
                        content_type = response.headers.get("Content-Type", "")
                        if "text/event-stream" in content_type:
                            raise ValueError("Received an event-stream although request stream was False")

@@ -537,12 +572,13 @@ class OpenAIDisaggServer:
                    logger.error(f"Error encountered while processing request to {url+endpoint}: {e}")
                    raise

    async def send_completion_request(self, url: str, request: CompletionRequest,
                                      trace_headers: Optional[Mapping[str, str]] = None) -> Union[CompletionResponse, StreamingResponse]:
        return await self.send_request(url, request, "/v1/completions", CompletionResponse, self.create_completion_generator, trace_headers)

    async def send_completion_request(self, url: str, request: CompletionRequest) -> Union[CompletionResponse, StreamingResponse]:
        return await self.send_request(url, request, "/v1/completions", CompletionResponse, self.create_completion_generator)

    async def send_chat_request(self, url: str, request: ChatCompletionRequest) -> ChatCompletionResponse:
        return await self.send_request(url, request, "/v1/chat/completions", ChatCompletionResponse, self.create_chat_generator)
    async def send_chat_request(self, url: str, request: ChatCompletionRequest,
                                trace_headers: Optional[Mapping[str, str]] = None) -> ChatCompletionResponse:
        return await self.send_request(url, request, "/v1/chat/completions", ChatCompletionResponse, self.create_chat_generator, trace_headers)

    async def set_steady_clock_offsets(self, session: aiohttp.ClientSession):
        STEADY_CLOCK_OFFSET_ENDPOINT = "/steady_clock_offset"
@@ -28,7 +28,7 @@ from tensorrt_llm.inputs.data import TokensPrompt
from tensorrt_llm.inputs.multimodal import MultimodalServerConfig
from tensorrt_llm.inputs.utils import ConversationMessage, apply_chat_template
from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams
from tensorrt_llm.llmapi import MultimodalEncoder
from tensorrt_llm.llmapi import MultimodalEncoder, tracing
from tensorrt_llm.llmapi.disagg_utils import (DisaggClusterConfig,
                                              MetadataServerConfig, ServerRole)
from tensorrt_llm.llmapi.llm import RequestOutput

@@ -541,6 +541,8 @@ class OpenAIServer:
            postproc_args=postproc_args,
        )

        trace_headers = (None if raw_request is None else tracing.extract_trace_headers(raw_request.headers))

        promise = self.llm.generate_async(
            inputs=prompt,
            sampling_params=sampling_params,

@@ -549,6 +551,7 @@ class OpenAIServer:
            lora_request=request.lora_request,
            disaggregated_params=disaggregated_params,
            cache_salt=request.cache_salt,
            trace_headers=trace_headers,
        )
        asyncio.create_task(self.await_disconnected(raw_request, promise))
        if not self.postproc_worker_enabled:

@@ -763,6 +766,7 @@ class OpenAIServer:
            if request.stream else completion_response_post_processor,
            postproc_args=postproc_args,
        )
        trace_headers = (None if raw_request is None else tracing.extract_trace_headers(raw_request.headers))

        prompt = prompt_inputs(prompt)
        if prompt.get("prompt") is not None:

@@ -777,7 +781,8 @@ class OpenAIServer:
            _postproc_params=postproc_params,
            streaming=request.stream,
            lora_request=request.lora_request,
            disaggregated_params=disaggregated_params
            disaggregated_params=disaggregated_params,
            trace_headers=trace_headers
        )
        asyncio.create_task(self.await_disconnected(raw_request, promise))
        if not self.postproc_worker_enabled:
@@ -3,6 +3,7 @@ import copy
import inspect
import os
import pathlib
from collections.abc import Mapping
from dataclasses import _HAS_DEFAULT_FACTORY_CLASS, dataclass, fields
from pprint import pprint
from types import MethodType, NoneType
@@ -187,6 +187,10 @@ methods:
        annotation: Optional[tensorrt_llm.llmapi.llm_args.SparseAttentionConfig]
        default: null
        status: prototype
      otlp_traces_endpoint:
        annotation: Optional[str]
        default: null
        status: prototype
      ray_worker_extension_cls:
        annotation: Optional[str]
        default: null

@@ -222,6 +226,10 @@ methods:
      cache_salt:
        annotation: Optional[str]
        default: null
      trace_headers:
        annotation: Optional[Mapping[str, str]]
        default: null
        status: prototype
    return_annotation: tensorrt_llm.llmapi.llm.RequestOutput
  get_kv_cache_events:
    parameters:

@@ -20,4 +20,13 @@ methods:
        annotation: Optional[dict[str, float]]
        default: None
    return_annotation: None
  do_tracing:
    parameters:
      output:
        annotation: tensorrt_llm.executor.result.CompletionOutput
        default: inspect._empty
      req_perf_metrics_dict:
        annotation: Optional[dict[str, float]]
        default: None
    return_annotation: None
properties: {}