[Feature] OTEL tracing during loading (#31162)

This commit is contained in:
emricksini-h
2026-02-06 01:59:28 +01:00
committed by GitHub
parent 91a07ff618
commit 325ab6b0a8
29 changed files with 873 additions and 280 deletions
+1 -1
View File
@@ -52,4 +52,4 @@ anthropic >= 0.71.0
model-hosting-container-standards >= 0.1.13, < 1.0.0
mcp
grpcio
grpcio-reflection
grpcio-reflection
+7
View File
@@ -1049,6 +1049,13 @@ setup(
"petit-kernel": ["petit-kernel"],
# Optional deps for Helion kernel development
"helion": ["helion"],
# Optional deps for OpenTelemetry tracing
"otel": [
"opentelemetry-sdk>=1.26.0",
"opentelemetry-api>=1.26.0",
"opentelemetry-exporter-otlp>=1.26.0",
"opentelemetry-semantic-conventions-ai>=0.4.1",
],
},
cmdclass=cmdclass,
package_data=package_data,
View File
+127
View File
@@ -0,0 +1,127 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import threading
from collections.abc import Callable, Generator, Iterable
from concurrent import futures
from typing import Any, Literal
import grpc
import pytest
from opentelemetry.proto.collector.trace.v1.trace_service_pb2 import (
ExportTraceServiceRequest,
ExportTraceServiceResponse,
)
from opentelemetry.proto.collector.trace.v1.trace_service_pb2_grpc import (
TraceServiceServicer,
add_TraceServiceServicer_to_server,
)
from opentelemetry.proto.common.v1.common_pb2 import AnyValue, KeyValue
FAKE_TRACE_SERVER_ADDRESS = "localhost:4317"
FieldName = Literal[
"bool_value", "string_value", "int_value", "double_value", "array_value"
]
def decode_value(value: AnyValue):
    """Convert an OpenTelemetry ``AnyValue`` protobuf message to a Python value.

    Scalar fields (bool, string, int, double) are returned as-is; an
    ``array_value`` is decoded element by element into a list.

    Raises:
        ValueError: If none of the supported fields is set on ``value``.
    """
    # Scalars are probed in a fixed order; the first populated field wins.
    for scalar_field in ("bool_value", "string_value", "int_value", "double_value"):
        if value.HasField(scalar_field):
            return getattr(value, scalar_field)

    # Arrays are checked last and decoded recursively.
    if value.HasField("array_value"):
        return [decode_value(element) for element in value.array_value.values]

    raise ValueError(f"Couldn't decode value: {value}")
def decode_attributes(attributes: Iterable[KeyValue]) -> dict[str, Any]:
    """Build a plain ``{key: python_value}`` dict from OTel KeyValue pairs.

    Later entries overwrite earlier ones that share the same key.
    """
    decoded: dict[str, Any] = {}
    for pair in attributes:
        decoded[pair.key] = decode_value(pair.value)
    return decoded
class FakeTraceService(TraceServiceServicer):
    """A fake gRPC trace service for testing OpenTelemetry trace exports.

    Every ``Export`` request is kept in memory so that tests can inspect the
    spans that were sent to the service.
    """

    def __init__(self):
        # All export requests received so far, in arrival order.
        self.requests: list[ExportTraceServiceRequest] = []
        # Set once at least one export request has arrived.
        self.evt = threading.Event()
        # Guards ``self.requests``; Export may run concurrently with readers.
        self._lock = threading.Lock()

    def Export(self, request, context):
        """Record an export request and acknowledge it with an empty response."""
        with self._lock:
            self.requests.append(request)
        self.evt.set()
        return ExportTraceServiceResponse()

    @property
    def request(self) -> ExportTraceServiceRequest | None:
        """Returns the first request received (for backward compatibility)."""
        with self._lock:
            return self.requests[0] if self.requests else None

    def get_all_spans(self) -> list[dict]:
        """Returns all spans from all received requests as decoded dicts.

        Each dict has the keys ``name``, ``attributes``, ``trace_id``,
        ``span_id``, ``parent_span_id`` (``None`` when empty),
        ``start_time_unix_nano`` and ``end_time_unix_nano``.
        """
        # Snapshot under the lock, then decode outside of it so that a slow
        # decode never blocks an incoming Export call.
        with self._lock:
            received = list(self.requests)

        return [
            {
                "name": span.name,
                "attributes": decode_attributes(span.attributes),
                "trace_id": span.trace_id.hex(),
                "span_id": span.span_id.hex(),
                "parent_span_id": span.parent_span_id.hex()
                if span.parent_span_id
                else None,
                "start_time_unix_nano": span.start_time_unix_nano,
                "end_time_unix_nano": span.end_time_unix_nano,
            }
            for request in received
            for resource_span in request.resource_spans
            for scope_span in resource_span.scope_spans
            for span in scope_span.spans
        ]

    def wait_for_spans(self, count: int = 1, timeout: float = 10) -> bool:
        """Wait until at least `count` spans have been received.

        Args:
            count: Minimum number of spans to wait for.
            timeout: Maximum time to wait, in seconds.

        Returns:
            True if enough spans arrived before the timeout, else False.
        """
        import time

        # A monotonic clock keeps the timeout immune to wall-clock adjustments.
        deadline = time.monotonic() + timeout
        while True:
            # Check first so that already-received spans are seen even when
            # ``timeout`` is zero.
            if len(self.get_all_spans()) >= count:
                return True
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                return False
            # Never sleep past the deadline.
            time.sleep(min(0.1, remaining))

    def clear(self):
        """Clear all received requests."""
        with self._lock:
            self.requests.clear()
        self.evt.clear()
@pytest.fixture
def trace_service() -> Generator[FakeTraceService, None, None]:
    """Run a fake gRPC trace service for the duration of a test.

    The server listens on ``FAKE_TRACE_SERVER_ADDRESS`` and is stopped once
    the test that requested the fixture has finished.
    """
    worker_pool = futures.ThreadPoolExecutor(max_workers=2)
    grpc_server = grpc.server(worker_pool)
    fake_service = FakeTraceService()
    add_TraceServiceServicer_to_server(fake_service, grpc_server)
    grpc_server.add_insecure_port(FAKE_TRACE_SERVER_ADDRESS)
    grpc_server.start()

    yield fake_service

    # Teardown: shut the server down without a grace period.
    grpc_server.stop(grace=None)
@pytest.fixture
def trace_server_address() -> str:
    """Provide the address on which the fake trace server listens."""
    address: str = FAKE_TRACE_SERVER_ADDRESS
    return address
+87
View File
@@ -0,0 +1,87 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import pytest
from opentelemetry.sdk.environment_variables import OTEL_EXPORTER_OTLP_TRACES_INSECURE
from tests.tracing.conftest import FAKE_TRACE_SERVER_ADDRESS, FakeTraceService
from vllm.tracing import init_tracer, instrument, is_otel_available
# Skip everything if OTel is missing
pytestmark = pytest.mark.skipif(not is_otel_available(), reason="OTel required")
class TestCoreInstrumentation:
    """Checks that the ``@instrument`` decorator emits the expected spans."""

    @pytest.fixture(autouse=True)
    def setup_tracing(self, monkeypatch):
        """Initialize a tracer that exports to the fake trace server."""
        monkeypatch.setenv(OTEL_EXPORTER_OTLP_TRACES_INSECURE, "true")
        init_tracer("test.core", FAKE_TRACE_SERVER_ADDRESS)

    def test_decorator_captures_sync_and_async(self, trace_service: FakeTraceService):
        """A span is exported for a sync and for an async decorated function."""

        @instrument(span_name="sync_task")
        def sync_task():
            return True

        @instrument(span_name="async_task")
        async def async_task():
            return True

        sync_task()
        asyncio.run(async_task())

        assert trace_service.wait_for_spans(count=2)

        exported_names = {span["name"] for span in trace_service.get_all_spans()}
        assert "sync_task" in exported_names
        assert "async_task" in exported_names

    def test_nested_spans_hierarchy(self, trace_service: FakeTraceService):
        """A span started inside another span records it as its parent."""

        @instrument(span_name="child")
        def child():
            pass

        @instrument(span_name="parent")
        def parent():
            child()

        parent()

        assert trace_service.wait_for_spans(count=2)

        exported = trace_service.get_all_spans()

        def first_named(name):
            # First exported span carrying the given name.
            return next(span for span in exported if span["name"] == name)

        parent_span = first_named("parent")
        child_span = first_named("child")
        assert child_span["parent_span_id"] == parent_span["span_id"]
class TestInterProcessPropagation:
    """Checks that trace context found in the environment is picked up."""

    def test_pickup_external_context(self, monkeypatch, trace_service):
        """A span joins the trace described by the ``traceparent`` env var."""
        # Simulate a parent trace that was started outside of this process.
        external_trace_id = "4bf92f3577b34da6a3ce929d0e0e4736"
        external_parent_id = "00f067aa0ba902b7"

        monkeypatch.setenv(OTEL_EXPORTER_OTLP_TRACES_INSECURE, "true")
        monkeypatch.setenv(
            "traceparent", f"00-{external_trace_id}-{external_parent_id}-01"
        )

        init_tracer("test.external", FAKE_TRACE_SERVER_ADDRESS)

        @instrument(span_name="follower")
        def follower_func():
            pass

        follower_func()

        assert trace_service.wait_for_spans(count=1)

        exported = trace_service.get_all_spans()
        first_span = exported[0]
        assert first_span["trace_id"] == external_trace_id
        assert first_span["parent_span_id"] == external_parent_id
+23 -84
View File
@@ -2,76 +2,19 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa
# type: ignore
import threading
from collections.abc import Iterable
from concurrent import futures
from typing import Callable, Generator, Literal
import grpc
import pytest
from opentelemetry.proto.collector.trace.v1.trace_service_pb2 import (
ExportTraceServiceResponse,
)
from opentelemetry.proto.collector.trace.v1.trace_service_pb2_grpc import (
TraceServiceServicer,
add_TraceServiceServicer_to_server,
)
from opentelemetry.proto.common.v1.common_pb2 import AnyValue, KeyValue
import time
from opentelemetry.sdk.environment_variables import OTEL_EXPORTER_OTLP_TRACES_INSECURE
from vllm import LLM, SamplingParams
from vllm.tracing import SpanAttributes
FAKE_TRACE_SERVER_ADDRESS = "localhost:4317"
FieldName = Literal[
"bool_value", "string_value", "int_value", "double_value", "array_value"
]
def decode_value(value: AnyValue):
field_decoders: dict[FieldName, Callable] = {
"bool_value": (lambda v: v.bool_value),
"string_value": (lambda v: v.string_value),
"int_value": (lambda v: v.int_value),
"double_value": (lambda v: v.double_value),
"array_value": (
lambda v: [decode_value(item) for item in v.array_value.values]
),
}
for field, decoder in field_decoders.items():
if value.HasField(field):
return decoder(value)
raise ValueError(f"Couldn't decode value: {value}")
def decode_attributes(attributes: Iterable[KeyValue]):
return {kv.key: decode_value(kv.value) for kv in attributes}
class FakeTraceService(TraceServiceServicer):
def __init__(self):
self.request = None
self.evt = threading.Event()
def Export(self, request, context):
self.request = request
self.evt.set()
return ExportTraceServiceResponse()
@pytest.fixture
def trace_service() -> Generator[FakeTraceService, None, None]:
"""Fixture to set up a fake gRPC trace service"""
server = grpc.server(futures.ThreadPoolExecutor(max_workers=1))
service = FakeTraceService()
add_TraceServiceServicer_to_server(service, server)
server.add_insecure_port(FAKE_TRACE_SERVER_ADDRESS)
server.start()
yield service
server.stop(None)
# Import shared fixtures from the tracing conftest
from tests.tracing.conftest import ( # noqa: F401
FAKE_TRACE_SERVER_ADDRESS,
FakeTraceService,
trace_service,
)
def test_traces(
@@ -97,29 +40,25 @@ def test_traces(
outputs = llm.generate(prompts, sampling_params=sampling_params)
print(f"test_traces outputs is : {outputs}")
timeout = 10
if not trace_service.evt.wait(timeout):
raise TimeoutError(
f"The fake trace service didn't receive a trace within "
f"the {timeout} seconds timeout"
)
# Wait for the "llm_request" span to be exported.
# The BatchSpanProcessor batches spans and exports them periodically,
# so we need to wait specifically for the llm_request span to appear.
timeout = 15
deadline = time.time() + timeout
llm_request_spans = []
while time.time() < deadline:
all_spans = trace_service.get_all_spans()
llm_request_spans = [s for s in all_spans if s["name"] == "llm_request"]
if llm_request_spans:
break
time.sleep(0.5)
request = trace_service.request
assert len(request.resource_spans) == 1, (
f"Expected 1 resource span, but got {len(request.resource_spans)}"
)
assert len(request.resource_spans[0].scope_spans) == 1, (
f"Expected 1 scope span, "
f"but got {len(request.resource_spans[0].scope_spans)}"
)
assert len(request.resource_spans[0].scope_spans[0].spans) == 1, (
f"Expected 1 span, "
f"but got {len(request.resource_spans[0].scope_spans[0].spans)}"
assert len(llm_request_spans) == 1, (
f"Expected exactly 1 'llm_request' span, but got {len(llm_request_spans)}. "
f"All span names: {[s['name'] for s in all_spans]}"
)
attributes = decode_attributes(
request.resource_spans[0].scope_spans[0].spans[0].attributes
)
attributes = llm_request_spans[0]["attributes"]
# assert attributes.get(SpanAttributes.GEN_AI_RESPONSE_MODEL) == model
assert attributes.get(SpanAttributes.GEN_AI_REQUEST_ID) == outputs[0].request_id
assert (
+8
View File
@@ -33,6 +33,7 @@ from vllm.config.utils import Range, hash_factors
from vllm.logger import init_logger
from vllm.logging_utils import lazy
from vllm.platforms import current_platform
from vllm.tracing import instrument, instrument_manual
from vllm.utils.import_utils import resolve_obj_by_qualname
from .compiler_interface import (
@@ -234,6 +235,7 @@ class CompilerManager:
)
return compiled_graph
@instrument(span_name="Compile graph")
def compile(
self,
graph: fx.GraphModule,
@@ -497,6 +499,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): # type: ignore[misc]
# When True, it annoyingly dumps the torch.fx.Graph on errors.
self.extra_traceback = False
@instrument(span_name="Inductor compilation")
def run(self, *args: Any) -> Any:
# maybe instead just assert inputs are fake?
fake_args = [
@@ -922,6 +925,11 @@ class VllmBackend:
)
self.compilation_config.compilation_time += dynamo_time
# Record Dynamo time in tracing if available
start_time = int(torch_compile_start_time * 1e9)
attributes = {"dynamo.time_seconds": dynamo_time}
instrument_manual("Dynamo bytecode transform", start_time, None, attributes)
# we control the compilation process, each instance can only be
# called once
assert not self._called, "VllmBackend can only be called once"
+2 -2
View File
@@ -122,9 +122,9 @@ class ObservabilityConfig:
@classmethod
def _validate_otlp_traces_endpoint(cls, value: str | None) -> str | None:
if value is not None:
from vllm.tracing import is_otel_available, otel_import_error_traceback
from vllm.tracing import is_tracing_available, otel_import_error_traceback
if not is_otel_available():
if not is_tracing_available():
raise ValueError(
"OpenTelemetry is not available. Unable to configure "
"'otlp_traces_endpoint'. Ensure OpenTelemetry packages are "
+2
View File
@@ -50,6 +50,7 @@ from vllm.logger import init_logger
from vllm.reasoning import ReasoningParserManager
from vllm.tasks import POOLING_TASKS, SupportedTask
from vllm.tool_parsers import ToolParserManager
from vllm.tracing import instrument
from vllm.usage.usage_lib import UsageContext
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.network_utils import is_valid_ipv6_address
@@ -377,6 +378,7 @@ def validate_api_server_args(args):
)
@instrument(span_name="API server setup")
def setup_server(args):
"""Validate API server args, set up signal handler, create socket
ready to serve."""
@@ -14,6 +14,7 @@ from vllm.model_executor.model_loader.utils import (
process_weights_after_loading,
)
from vllm.platforms import current_platform
from vllm.tracing import instrument
from vllm.utils.mem_utils import format_gib
from vllm.utils.torch_utils import set_default_torch_dtype
@@ -37,6 +38,7 @@ class BaseModelLoader(ABC):
inplace weights loading for an already-initialized model"""
raise NotImplementedError
@instrument(span_name="Load model")
def load_model(
self, vllm_config: VllmConfig, model_config: ModelConfig, prefix: str = ""
) -> nn.Module:
@@ -30,6 +30,7 @@ from vllm.model_executor.model_loader.weight_utils import (
pt_weights_iterator,
safetensors_weights_iterator,
)
from vllm.tracing import instrument
from vllm.transformers_utils.repo_utils import list_filtered_repo_files
logger = init_logger(__name__)
@@ -274,6 +275,7 @@ class DefaultModelLoader(BaseModelLoader):
allow_patterns_overrides=None,
)
@instrument(span_name="Load weights")
def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
if model_config.quantization == "torchao":
quant_config = get_quant_config(model_config, self.load_config)
@@ -23,11 +23,13 @@ from vllm.model_executor.model_loader.reload import (
set_torchao_reload_attrs,
)
from vllm.model_executor.models.interfaces import SupportsQuant
from vllm.tracing import instrument
from vllm.utils.platform_utils import is_pin_memory_available
logger = init_logger(__name__)
@instrument(span_name="Initialize model")
def initialize_model(
vllm_config: VllmConfig,
*,
@@ -36,6 +36,7 @@ from vllm.model_executor.layers.quantization import (
get_quantization_config,
)
from vllm.platforms import current_platform
from vllm.tracing import instrument
from vllm.utils.import_utils import PlaceholderModule
try:
@@ -443,6 +444,7 @@ def download_gguf(
return local_files[0]
@instrument(span_name="Download weights - HF")
def download_weights_from_hf(
model_name_or_path: str,
cache_dir: str | None,
@@ -19,6 +19,7 @@ from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
)
from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
from vllm.tracing import instrument
from vllm.utils.deep_gemm import (
fp8_gemm_nt,
get_mk_alignment_for_contiguous_layout,
@@ -358,6 +359,7 @@ def _count_warmup_iterations(model: torch.nn.Module, max_tokens: int) -> int:
return total
@instrument(span_name="DeepGemm warmup")
def deep_gemm_warmup(model: torch.nn.Module, max_tokens: int):
total = _count_warmup_iterations(model, max_tokens)
if total == 0:
-135
View File
@@ -1,135 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from collections.abc import Mapping
from vllm.logger import init_logger
from vllm.utils.func_utils import run_once
TRACE_HEADERS = ["traceparent", "tracestate"]
logger = init_logger(__name__)
_is_otel_imported = False
otel_import_error_traceback: str | None = None
try:
from opentelemetry.context.context import Context
from opentelemetry.sdk.environment_variables import (
OTEL_EXPORTER_OTLP_TRACES_PROTOCOL,
)
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.trace import SpanKind, Tracer, set_tracer_provider
from opentelemetry.trace.propagation.tracecontext import (
TraceContextTextMapPropagator,
)
_is_otel_imported = True
except ImportError:
# Capture and format traceback to provide detailed context for the import
# error. Only the string representation of the error is retained to avoid
# memory leaks.
# See https://github.com/vllm-project/vllm/pull/7266#discussion_r1707395458
import traceback
otel_import_error_traceback = traceback.format_exc()
class Context: # type: ignore
pass
class BaseSpanAttributes: # type: ignore
pass
class SpanKind: # type: ignore
pass
class Tracer: # type: ignore
pass
def is_otel_available() -> bool:
return _is_otel_imported
def init_tracer(
instrumenting_module_name: str, otlp_traces_endpoint: str
) -> Tracer | None:
if not is_otel_available():
raise ValueError(
"OpenTelemetry is not available. Unable to initialize "
"a tracer. Ensure OpenTelemetry packages are installed. "
f"Original error:\n{otel_import_error_traceback}"
)
trace_provider = TracerProvider()
span_exporter = get_span_exporter(otlp_traces_endpoint)
trace_provider.add_span_processor(BatchSpanProcessor(span_exporter))
set_tracer_provider(trace_provider)
tracer = trace_provider.get_tracer(instrumenting_module_name)
return tracer
def get_span_exporter(endpoint):
protocol = os.environ.get(OTEL_EXPORTER_OTLP_TRACES_PROTOCOL, "grpc")
if protocol == "grpc":
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import (
OTLPSpanExporter,
)
elif protocol == "http/protobuf":
from opentelemetry.exporter.otlp.proto.http.trace_exporter import (
OTLPSpanExporter, # type: ignore
)
else:
raise ValueError(f"Unsupported OTLP protocol '{protocol}' is configured")
return OTLPSpanExporter(endpoint=endpoint)
def extract_trace_context(headers: Mapping[str, str] | None) -> Context | None:
if is_otel_available():
headers = headers or {}
return TraceContextTextMapPropagator().extract(headers)
else:
return None
def extract_trace_headers(headers: Mapping[str, str]) -> Mapping[str, str]:
return {h: headers[h] for h in TRACE_HEADERS if h in headers}
class SpanAttributes:
# Attribute names copied from here to avoid version conflicts:
# https://github.com/open-telemetry/semantic-conventions/blob/main/docs/gen-ai/gen-ai-spans.md
GEN_AI_USAGE_COMPLETION_TOKENS = "gen_ai.usage.completion_tokens"
GEN_AI_USAGE_PROMPT_TOKENS = "gen_ai.usage.prompt_tokens"
GEN_AI_REQUEST_MAX_TOKENS = "gen_ai.request.max_tokens"
GEN_AI_REQUEST_TOP_P = "gen_ai.request.top_p"
GEN_AI_REQUEST_TEMPERATURE = "gen_ai.request.temperature"
GEN_AI_RESPONSE_MODEL = "gen_ai.response.model"
# Attribute names added until they are added to the semantic conventions:
GEN_AI_REQUEST_ID = "gen_ai.request.id"
GEN_AI_REQUEST_N = "gen_ai.request.n"
GEN_AI_USAGE_NUM_SEQUENCES = "gen_ai.usage.num_sequences"
GEN_AI_LATENCY_TIME_IN_QUEUE = "gen_ai.latency.time_in_queue"
GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN = "gen_ai.latency.time_to_first_token"
GEN_AI_LATENCY_E2E = "gen_ai.latency.e2e"
GEN_AI_LATENCY_TIME_IN_SCHEDULER = "gen_ai.latency.time_in_scheduler"
# Time taken in the forward pass for this across all workers
GEN_AI_LATENCY_TIME_IN_MODEL_FORWARD = "gen_ai.latency.time_in_model_forward"
# Time taken in the model execute function. This will include model
# forward, block/sync across workers, cpu-gpu sync time and sampling time.
GEN_AI_LATENCY_TIME_IN_MODEL_EXECUTE = "gen_ai.latency.time_in_model_execute"
GEN_AI_LATENCY_TIME_IN_MODEL_PREFILL = "gen_ai.latency.time_in_model_prefill"
GEN_AI_LATENCY_TIME_IN_MODEL_DECODE = "gen_ai.latency.time_in_model_decode"
GEN_AI_LATENCY_TIME_IN_MODEL_INFERENCE = "gen_ai.latency.time_in_model_inference"
def contains_trace_headers(headers: Mapping[str, str]) -> bool:
return any(h in headers for h in TRACE_HEADERS)
@run_once
def log_tracing_disabled_warning() -> None:
logger.warning("Received a request with trace context but tracing is disabled")
+157
View File
@@ -0,0 +1,157 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from collections.abc import Callable
from typing import Any, TypeAlias
# Import the implementation details
from .otel import (
SpanKind,
extract_trace_context,
init_otel_tracer,
init_otel_worker_tracer,
instrument_otel,
is_otel_available,
manual_instrument_otel,
otel_import_error_traceback,
)
from .utils import (
SpanAttributes,
contains_trace_headers,
extract_trace_headers,
log_tracing_disabled_warning,
)
__all__ = [
"instrument",
"instrument_manual",
"init_tracer",
"maybe_init_worker_tracer",
"is_tracing_available",
"SpanAttributes",
"SpanKind",
"extract_trace_context",
"extract_trace_headers",
"log_tracing_disabled_warning",
"contains_trace_headers",
"otel_import_error_traceback",
]
BackendAvailableFunc: TypeAlias = Callable[[], bool]
InstrumentFunc: TypeAlias = Callable[..., Any]
InstrumentManualFunc: TypeAlias = Callable[..., Any]
InitTracerFunc: TypeAlias = Callable[..., Any]
InitWorkerTracerFunc: TypeAlias = Callable[..., Any]
_REGISTERED_TRACING_BACKENDS: dict[
str,
tuple[
BackendAvailableFunc,
InitTracerFunc,
InitWorkerTracerFunc,
InstrumentFunc,
InstrumentManualFunc,
],
] = {
"otel": (
is_otel_available,
init_otel_tracer,
init_otel_worker_tracer,
instrument_otel,
manual_instrument_otel,
),
}
def init_tracer(
    instrumenting_module_name: str,
    otlp_traces_endpoint: str,
    extra_attributes: dict[str, str] | None = None,
):
    """Initialize the tracer of the "otel" backend, if that backend is available.

    Args:
        instrumenting_module_name: Name passed through to the backend.
        otlp_traces_endpoint: Endpoint passed through to the backend.
        extra_attributes: Optional extra attributes passed through to the backend.

    Returns:
        Whatever the backend's init function returns, or None when the
        backend is not available.
    """
    backend_available, backend_init, _, _, _ = _REGISTERED_TRACING_BACKENDS["otel"]
    if not backend_available():
        return None
    return backend_init(
        instrumenting_module_name, otlp_traces_endpoint, extra_attributes
    )
def maybe_init_worker_tracer(
    instrumenting_module_name: str,
    process_kind: str,
    process_name: str,
):
    """Initialize the "otel" backend's worker tracer, if the backend is available.

    Args:
        instrumenting_module_name: Name passed through to the backend.
        process_kind: Kind of the calling process, passed through to the backend.
        process_name: Name of the calling process, passed through to the backend.

    Returns:
        Whatever the backend's worker init function returns, or None when
        the backend is not available.
    """
    backend = _REGISTERED_TRACING_BACKENDS["otel"]
    backend_available, backend_worker_init = backend[0], backend[2]
    if not backend_available():
        return None
    return backend_worker_init(instrumenting_module_name, process_kind, process_name)
def instrument(
obj: Callable | None = None,
*,
span_name: str = "",
attributes: dict[str, str] | None = None,
record_exception: bool = True,
):
"""
Generic decorator to instrument functions.
"""
if obj is None:
return functools.partial(
instrument,
span_name=span_name,
attributes=attributes,
record_exception=record_exception,
)
# Dispatch to OTel (and potentially others later)
is_available, _, _, otel_instrument, _ = _REGISTERED_TRACING_BACKENDS["otel"]
if is_available():
return otel_instrument(
func=obj,
span_name=span_name,
attributes=attributes,
record_exception=record_exception,
)
else:
return obj
def instrument_manual(
    span_name: str,
    start_time: int,
    end_time: int | None = None,
    attributes: dict[str, Any] | None = None,
    context: Any = None,
    kind: Any = None,
):
    """Create a span with explicit timestamps through the "otel" backend.

    Args:
        span_name: Name of the span to create.
        start_time: Start time in nanoseconds since epoch.
        end_time: Optional end time in nanoseconds. If None, ends immediately.
        attributes: Optional dict of span attributes.
        context: Optional trace context (e.g., from extract_trace_context).
        kind: Optional SpanKind (e.g., SpanKind.SERVER).

    Returns:
        Whatever the backend returns, or None when the backend is not
        available.
    """
    backend = _REGISTERED_TRACING_BACKENDS["otel"]
    backend_available, backend_manual = backend[0], backend[4]
    if not backend_available():
        return None
    return backend_manual(span_name, start_time, end_time, attributes, context, kind)
def is_tracing_available() -> bool:
    """
    Returns True if any tracing backend (OTel, Profiler, etc.) is available.
    Use this to guard expensive tracing logic in the main code.
    """
    # The first tuple entry is the availability check; it must be *called*.
    # Testing the function object itself is always truthy and would report
    # tracing as available whenever any backend is registered.
    return any(backend[0]() for backend in _REGISTERED_TRACING_BACKENDS.values())
+265
View File
@@ -0,0 +1,265 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import atexit
import functools
import inspect
import os
import traceback
from collections.abc import Mapping
from contextlib import contextmanager
from typing import Any
from vllm.logger import init_logger
from vllm.tracing.utils import TRACE_HEADERS, LoadingSpanAttributes
logger = init_logger(__name__)
try:
from opentelemetry import trace
from opentelemetry.context.context import Context
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import (
OTLPSpanExporter as OTLPGrpcExporter,
)
from opentelemetry.exporter.otlp.proto.http.trace_exporter import (
OTLPSpanExporter as OTLPHttpExporter,
)
from opentelemetry.propagate import inject
from opentelemetry.sdk.environment_variables import (
OTEL_EXPORTER_OTLP_TRACES_PROTOCOL,
)
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.trace import (
SpanKind, # noqa: F401
Tracer,
set_tracer_provider,
)
from opentelemetry.trace.propagation.tracecontext import (
TraceContextTextMapPropagator,
)
_IS_OTEL_AVAILABLE = True
otel_import_error_traceback = None
except ImportError:
_IS_OTEL_AVAILABLE = False
otel_import_error_traceback = traceback.format_exc()
trace = None # type: ignore
Context = Any # type: ignore
Tracer = Any # type: ignore
inject = None # type: ignore
Resource = None # type: ignore
SpanKind = Any # type: ignore
def is_otel_available() -> bool:
    """Report whether the OpenTelemetry imports of this module succeeded."""
    available = _IS_OTEL_AVAILABLE
    return available
def init_otel_tracer(
    instrumenting_module_name: str,
    otlp_traces_endpoint: str,
    extra_attributes: dict[str, str] | None = None,
) -> Tracer:
    """Create the OpenTelemetry tracer provider and install it as the global one.

    Args:
        instrumenting_module_name: Recorded as a resource attribute and used
            as the name of the returned tracer.
        otlp_traces_endpoint: Endpoint the span exporter sends to.
        extra_attributes: Optional resource attributes; they override the
            default ones on key clashes.

    Returns:
        A tracer obtained from the newly created provider.

    Raises:
        ValueError: If the OpenTelemetry packages could not be imported.
    """
    if not _IS_OTEL_AVAILABLE:
        raise ValueError(
            "OpenTelemetry is not available. Unable to initialize "
            "a tracer. Ensure OpenTelemetry packages are installed. "
            f"Original error:\n{otel_import_error_traceback}"
        )

    # Publish the endpoint through the environment so that child processes
    # can pick it up.
    os.environ["OTEL_EXPORTER_OTLP_TRACES_ENDPOINT"] = otlp_traces_endpoint

    resource_attrs = {
        "vllm.instrumenting_module_name": instrumenting_module_name,
        "vllm.process_id": str(os.getpid()),
        **(extra_attributes or {}),
    }

    provider = TracerProvider(resource=Resource.create(resource_attrs))
    exporter = get_span_exporter(otlp_traces_endpoint)
    provider.add_span_processor(BatchSpanProcessor(exporter))
    set_tracer_provider(provider)

    # Shut the provider down when the interpreter exits.
    atexit.register(provider.shutdown)

    return provider.get_tracer(instrumenting_module_name)
def get_span_exporter(endpoint):
    """Build the OTLP span exporter for the configured protocol.

    The protocol is read from the ``OTEL_EXPORTER_OTLP_TRACES_PROTOCOL``
    environment variable and defaults to ``"grpc"``.

    For gRPC, the connection stays insecure by default, but an explicit
    ``OTEL_EXPORTER_OTLP_TRACES_INSECURE`` (or, failing that,
    ``OTEL_EXPORTER_OTLP_INSECURE``) setting is honored, so TLS can be
    requested by setting it to anything other than ``"true"``.

    Args:
        endpoint: Endpoint the exporter sends spans to.

    Returns:
        The gRPC or HTTP OTLP span exporter.

    Raises:
        ValueError: If the configured protocol is not supported.
    """
    protocol = os.environ.get(OTEL_EXPORTER_OTLP_TRACES_PROTOCOL, "grpc")

    if protocol == "grpc":
        # The signal-specific variable wins over the generic one.
        insecure_setting = os.environ.get(
            "OTEL_EXPORTER_OTLP_TRACES_INSECURE",
            os.environ.get("OTEL_EXPORTER_OTLP_INSECURE"),
        )
        if insecure_setting is None:
            # Nothing configured: keep the previous always-insecure default.
            insecure = True
        else:
            insecure = insecure_setting.strip().lower() == "true"
        return OTLPGrpcExporter(endpoint=endpoint, insecure=insecure)

    if protocol == "http/protobuf":
        return OTLPHttpExporter(endpoint=endpoint)

    raise ValueError(f"Unsupported OTLP protocol '{protocol}' is configured")
def init_otel_worker_tracer(
    instrumenting_module_name: str,
    process_kind: str,
    process_name: str,
) -> Tracer | None:
    """
    Backend-specific initialization for OpenTelemetry in a worker process.

    The tracer is only initialized when the
    ``OTEL_EXPORTER_OTLP_TRACES_ENDPOINT`` environment variable is set
    (``init_otel_tracer`` stores the endpoint there).

    Args:
        instrumenting_module_name: Passed through to ``init_otel_tracer``.
        process_kind: Recorded as the ``vllm.process_kind`` attribute.
        process_name: Recorded as the ``vllm.process_name`` attribute.

    Returns:
        The tracer from ``init_otel_tracer``, or None when no endpoint is
        configured.
    """
    otlp_endpoint = os.environ.get("OTEL_EXPORTER_OTLP_TRACES_ENDPOINT")
    if not otlp_endpoint:
        # No endpoint configured: nothing to initialize.
        return None

    extra_attrs = {
        "vllm.process_kind": process_kind,
        "vllm.process_name": process_name,
    }
    return init_otel_tracer(instrumenting_module_name, otlp_endpoint, extra_attrs)
def extract_trace_context(headers: Mapping[str, str] | None) -> Context | None:
    """Build a trace context from W3C trace headers.

    Returns None when OpenTelemetry is unavailable or when ``headers`` is
    None or empty.
    """
    if not _IS_OTEL_AVAILABLE or not headers:
        return None
    propagator = TraceContextTextMapPropagator()
    return propagator.extract(headers)
def instrument_otel(func, span_name, attributes, record_exception):
    """Wrap ``func`` so that every call runs inside an OpenTelemetry span.

    Args:
        func: The sync or async function to wrap.
        span_name: Span name; falls back to ``func.__qualname__`` when empty.
        attributes: Optional attributes merged over the code-location ones.
        record_exception: Passed to ``start_as_current_span``.

    Returns:
        An async wrapper for coroutine functions, a sync wrapper otherwise.
    """
    owner_module = func.__module__
    code = func.__code__

    # Code-location attributes never change, so they are computed once here
    # rather than on every call.
    span_attrs = {
        LoadingSpanAttributes.CODE_FUNCTION: func.__qualname__,
        LoadingSpanAttributes.CODE_NAMESPACE: owner_module,
        LoadingSpanAttributes.CODE_FILEPATH: code.co_filename,
        LoadingSpanAttributes.CODE_LINENO: str(code.co_firstlineno),
    }
    if attributes:
        span_attrs.update(attributes)

    resolved_name = span_name or func.__qualname__

    def _open_span():
        # The tracer and the parent context are resolved at call time.
        return trace.get_tracer(owner_module).start_as_current_span(
            resolved_name,
            context=_get_smart_context(),
            attributes=span_attrs,
            record_exception=record_exception,
        )

    if inspect.iscoroutinefunction(func):

        @functools.wraps(func)
        async def async_wrapper(*args, **kwargs):
            # The span is opened first so that the env injection sees it.
            with _open_span(), propagate_trace_to_env():
                return await func(*args, **kwargs)

        return async_wrapper

    @functools.wraps(func)
    def sync_wrapper(*args, **kwargs):
        # The span is opened first so that the env injection sees it.
        with _open_span(), propagate_trace_to_env():
            return func(*args, **kwargs)

    return sync_wrapper
def manual_instrument_otel(
    span_name: str,
    start_time: int,
    end_time: int | None = None,
    attributes: dict[str, Any] | None = None,
    context: Context | None = None,
    kind: Any = None,  # SpanKind, but typed as Any for when OTEL unavailable
):
    """Create a span with an explicit start time and end it right away.

    Args:
        span_name: Name of the span.
        start_time: Start timestamp passed to ``start_span``.
        end_time: Optional end timestamp; when None, ``end()`` is called
            without one.
        attributes: Optional attributes set on the span.
        context: Optional parent context; when None, it is determined by
            ``_get_smart_context``.
        kind: Optional span kind; only passed on when given.
    """
    if not _IS_OTEL_AVAILABLE:
        return

    tracer = trace.get_tracer(__name__)

    # An explicitly provided context takes precedence over detection.
    parent_ctx = _get_smart_context() if context is None else context

    start_kwargs: dict[str, Any] = {
        "name": span_name,
        "context": parent_ctx,
        "start_time": start_time,
    }
    if kind is not None:
        start_kwargs["kind"] = kind
    span = tracer.start_span(**start_kwargs)

    if attributes:
        span.set_attributes(attributes)

    # Only pass ``end_time`` when the caller supplied one.
    end_kwargs = {} if end_time is None else {"end_time": end_time}
    span.end(**end_kwargs)
def _get_smart_context() -> Context | None:
    """
    Determines the parent context for a new span.

    1. If a span is already active in this process, return None so that the
       caller falls back to the current context.
    2. Otherwise, extract the context from the ``traceparent`` /
       ``tracestate`` environment variables, accepting either the
       lower-case or the upper-case spelling.

    Returns:
        None when a valid span is active, else the extracted context.
    """
    current_span = trace.get_current_span()
    if current_span.get_span_context().is_valid:
        return None

    # Only the two W3C trace keys are looked up. Copying the whole
    # environment into the carrier could not add a "traceparent" entry that
    # these lookups did not already find, so no such fallback is needed.
    carrier = {}
    if tp := os.environ.get("traceparent", os.environ.get("TRACEPARENT")):  # noqa: SIM112
        carrier["traceparent"] = tp
    if ts := os.environ.get("tracestate", os.environ.get("TRACESTATE")):  # noqa: SIM112
        carrier["tracestate"] = ts

    return TraceContextTextMapPropagator().extract(carrier)
@contextmanager
def propagate_trace_to_env():
    """
    Expose the current OTel context through os.environ for the duration of
    the ``with`` block.

    Subprocesses (like vLLM workers) spawned inside the block inherit the
    injected trace headers. On exit, the trace header variables are put back
    to the values they had before. Does nothing when OpenTelemetry is not
    available.
    """
    if not _IS_OTEL_AVAILABLE:
        yield
        return

    # Remember what each trace header variable held (None = not set).
    saved = {}
    for header in TRACE_HEADERS:
        saved[header] = os.environ.get(header)

    try:
        # inject() writes 'traceparent' and 'tracestate' to os.environ
        inject(os.environ)
        yield
    finally:
        # Put every header back, removing those that were not set before.
        for header, previous in saved.items():
            if previous is not None:
                os.environ[header] = previous
            else:
                os.environ.pop(header, None)
+72
View File
@@ -0,0 +1,72 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Mapping
from vllm.logger import init_logger
from vllm.utils.func_utils import run_once
logger = init_logger(__name__)
# Standard W3C headers used for context propagation
TRACE_HEADERS = ["traceparent", "tracestate"]
class SpanAttributes:
    """
    Attribute keys attached to request spans.

    Most keys follow the OpenTelemetry Semantic Conventions. They are kept
    here as plain string constants so that any backend or logger can use
    them.
    """

    # Keys copied from the OTel semantic conventions (copied rather than
    # imported to avoid version conflicts).
    GEN_AI_REQUEST_MAX_TOKENS = "gen_ai.request.max_tokens"
    GEN_AI_REQUEST_TEMPERATURE = "gen_ai.request.temperature"
    GEN_AI_REQUEST_TOP_P = "gen_ai.request.top_p"
    GEN_AI_RESPONSE_MODEL = "gen_ai.response.model"
    GEN_AI_USAGE_PROMPT_TOKENS = "gen_ai.usage.prompt_tokens"
    GEN_AI_USAGE_COMPLETION_TOKENS = "gen_ai.usage.completion_tokens"

    # Custom keys, used until they are standardized.
    GEN_AI_REQUEST_ID = "gen_ai.request.id"
    GEN_AI_REQUEST_N = "gen_ai.request.n"
    GEN_AI_USAGE_NUM_SEQUENCES = "gen_ai.usage.num_sequences"
    GEN_AI_LATENCY_E2E = "gen_ai.latency.e2e"
    GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN = "gen_ai.latency.time_to_first_token"
    GEN_AI_LATENCY_TIME_IN_QUEUE = "gen_ai.latency.time_in_queue"
    GEN_AI_LATENCY_TIME_IN_SCHEDULER = "gen_ai.latency.time_in_scheduler"

    # Latency breakdowns.
    GEN_AI_LATENCY_TIME_IN_MODEL_EXECUTE = "gen_ai.latency.time_in_model_execute"
    GEN_AI_LATENCY_TIME_IN_MODEL_FORWARD = "gen_ai.latency.time_in_model_forward"
    GEN_AI_LATENCY_TIME_IN_MODEL_INFERENCE = "gen_ai.latency.time_in_model_inference"
    GEN_AI_LATENCY_TIME_IN_MODEL_PREFILL = "gen_ai.latency.time_in_model_prefill"
    GEN_AI_LATENCY_TIME_IN_MODEL_DECODE = "gen_ai.latency.time_in_model_decode"
class LoadingSpanAttributes:
    """Attribute keys for code-level tracing: where in the code a span came from."""

    # Source location: file path and line number.
    CODE_FILEPATH = "code.filepath"
    CODE_LINENO = "code.lineno"
    # Enclosing namespace and function name.
    CODE_NAMESPACE = "code.namespace"
    CODE_FUNCTION = "code.function"
def contains_trace_headers(headers: Mapping[str, str]) -> bool:
    """Return True if ``headers`` has at least one key listed in TRACE_HEADERS."""
    for name in TRACE_HEADERS:
        if name in headers:
            return True
    return False
def extract_trace_headers(headers: Mapping[str, str]) -> Mapping[str, str]:
    """
    Return a new dict holding only the trace-related entries of ``headers``.

    Keys are taken from TRACE_HEADERS; any that are missing from ``headers``
    are left out. Handy for logging or for handing the trace context to a
    non-OTel client.
    """
    selected: dict[str, str] = {}
    for name in TRACE_HEADERS:
        if name in headers:
            selected[name] = headers[name]
    return selected
@run_once
def log_tracing_disabled_warning() -> None:
    """
    Warn that a request carried trace context while tracing is disabled.

    NOTE(review): ``run_once`` presumably limits this to a single emission
    per process; confirm in ``vllm.utils.func_utils``.
    """
    message = "Received a request with trace context but tracing is disabled"
    logger.warning(message)
+6 -4
View File
@@ -110,6 +110,10 @@ class AsyncLLM(EngineClient):
self.model_config = vllm_config.model_config
self.vllm_config = vllm_config
self.observability_config = vllm_config.observability_config
tracing_endpoint = self.observability_config.otlp_traces_endpoint
if tracing_endpoint is not None:
init_tracer("vllm.llm_engine", tracing_endpoint)
self.log_requests = log_requests
custom_stat_loggers = list(stat_loggers or [])
@@ -136,10 +140,8 @@ class AsyncLLM(EngineClient):
log_stats=self.log_stats,
stream_interval=self.vllm_config.scheduler_config.stream_interval,
)
endpoint = self.observability_config.otlp_traces_endpoint
if endpoint is not None:
tracer = init_tracer("vllm.llm_engine", endpoint)
self.output_processor.tracer = tracer
if tracing_endpoint is not None:
self.output_processor.tracing_enabled = True
# EngineCore (starts the engine in background process).
self.engine_core = EngineCoreClient.make_async_mp_client(
+21
View File
@@ -24,6 +24,7 @@ from vllm.logging_utils.dump_input import dump_engine_exception
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.tasks import POOLING_TASKS, SupportedTask
from vllm.tracing import instrument, maybe_init_worker_tracer
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
from vllm.utils.gc_utils import (
freeze_gc_heap,
@@ -217,6 +218,7 @@ class EngineCore:
# environment variable overrides after this point)
enable_envs_cache()
@instrument(span_name="Prepare model")
def _initialize_kv_caches(
self, vllm_config: VllmConfig
) -> tuple[int, int, KVCacheConfig]:
@@ -658,6 +660,7 @@ class EngineCoreProc(EngineCore):
ENGINE_CORE_DEAD = b"ENGINE_CORE_DEAD"
@instrument(span_name="EngineCoreProc init")
def __init__(
self,
vllm_config: VllmConfig,
@@ -926,8 +929,18 @@ class EngineCoreProc(EngineCore):
data_parallel = parallel_config.data_parallel_size > 1 or dp_rank > 0
if data_parallel:
parallel_config.data_parallel_rank_local = local_dp_rank
maybe_init_worker_tracer(
instrumenting_module_name="vllm.engine_core",
process_kind="engine_core",
process_name=f"EngineCore_DP{dp_rank}",
)
set_process_title("EngineCore", f"DP{dp_rank}")
else:
maybe_init_worker_tracer(
instrumenting_module_name="vllm.engine_core",
process_kind="engine_core",
process_name="EngineCore",
)
set_process_title("EngineCore")
decorate_logs()
@@ -956,6 +969,7 @@ class EngineCoreProc(EngineCore):
parallel_config.data_parallel_rank = 0
engine_core = EngineCoreProc(*args, engine_index=dp_rank, **kwargs)
assert engine_core is not None
engine_core.run_busy_loop()
except SystemExit:
@@ -1485,6 +1499,13 @@ class EngineCoreActorMixin:
dp_rank: int = 0,
local_dp_rank: int = 0,
):
# Initialize tracer for distributed tracing if configured.
maybe_init_worker_tracer(
instrumenting_module_name="vllm.engine_core",
process_kind="engine_core",
process_name=f"DPEngineCoreActor_DP{dp_rank}",
)
self.addresses = addresses
vllm_config.parallel_config.data_parallel_index = dp_rank
vllm_config.parallel_config.data_parallel_rank_local = local_dp_rank
+4
View File
@@ -24,6 +24,7 @@ from vllm.envs import VLLM_ENGINE_READY_TIMEOUT_S
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.tasks import SupportedTask
from vllm.tracing import instrument
from vllm.utils.async_utils import in_loop
from vllm.utils.network_utils import (
close_sockets,
@@ -96,6 +97,7 @@ class EngineCoreClient(ABC):
return InprocClient(vllm_config, executor_class, log_stats)
@staticmethod
@instrument(span_name="Overall Loading")
def make_async_mp_client(
vllm_config: VllmConfig,
executor_class: type[Executor],
@@ -650,6 +652,7 @@ def _process_utility_output(
class SyncMPClient(MPClient):
"""Synchronous client for multi-proc EngineCore."""
@instrument(span_name="SyncMPClient init")
def __init__(
self, vllm_config: VllmConfig, executor_class: type[Executor], log_stats: bool
):
@@ -819,6 +822,7 @@ class SyncMPClient(MPClient):
class AsyncMPClient(MPClient):
"""Asyncio-compatible client for multi-proc EngineCore."""
@instrument(span_name="AsyncMPClient init")
def __init__(
self,
vllm_config: VllmConfig,
+2 -2
View File
@@ -100,8 +100,8 @@ class LLMEngine:
)
endpoint = self.observability_config.otlp_traces_endpoint
if endpoint is not None:
tracer = init_tracer("vllm.llm_engine", endpoint)
self.output_processor.tracer = tracer
init_tracer("vllm.llm_engine", endpoint)
self.output_processor.tracing_enabled = True
# EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs)
self.engine_core = EngineCoreClient.make_client(
+54 -52
View File
@@ -20,7 +20,12 @@ from vllm.outputs import (
)
from vllm.sampling_params import RequestOutputKind
from vllm.tokenizers import TokenizerLike
from vllm.tracing import SpanAttributes, SpanKind, Tracer, extract_trace_context
from vllm.tracing import (
SpanAttributes,
SpanKind,
extract_trace_context,
instrument_manual,
)
from vllm.utils import length_from_prompt_token_ids_or_embeds
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
from vllm.v1.engine.detokenizer import IncrementalDetokenizer
@@ -422,7 +427,7 @@ class OutputProcessor:
self.parent_requests: dict[str, ParentRequest] = {}
self.external_req_ids: defaultdict[str, list[str]] = defaultdict(list)
self.lora_states = LoRARequestStates(log_stats)
self.tracer: Tracer | None = None
self.tracing_enabled: bool = False
self._requests_drained = asyncio.Event()
self._requests_drained.set()
@@ -678,7 +683,7 @@ class OutputProcessor:
self._update_stats_from_finished(
req_state, finish_reason, iteration_stats
)
if self.tracer:
if self.tracing_enabled:
self.do_tracing(engine_core_output, req_state, iteration_stats)
return OutputProcessorOutput(
@@ -714,62 +719,59 @@ class OutputProcessor:
) -> None:
assert req_state.stats is not None
assert iteration_stats is not None
assert self.tracer is not None
arrival_time_nano_seconds = int(req_state.stats.arrival_time * 1e9)
metrics = req_state.stats
arrival_time_ns = int(metrics.arrival_time * 1e9)
trace_context = extract_trace_context(engine_core_output.trace_headers)
prompt_length = length_from_prompt_token_ids_or_embeds(
req_state.prompt_token_ids, req_state.prompt_embeds
)
with self.tracer.start_as_current_span(
"llm_request",
kind=SpanKind.SERVER,
context=trace_context,
start_time=arrival_time_nano_seconds,
) as span:
metrics = req_state.stats
e2e_time = iteration_stats.iteration_timestamp - metrics.arrival_time
queued_time = metrics.scheduled_ts - metrics.queued_ts
prefill_time = metrics.first_token_ts - metrics.scheduled_ts
decode_time = metrics.last_token_ts - metrics.first_token_ts
inference_time = metrics.last_token_ts - metrics.scheduled_ts
span.set_attribute(
SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN,
metrics.first_token_latency,
)
span.set_attribute(SpanAttributes.GEN_AI_LATENCY_E2E, e2e_time)
span.set_attribute(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE, queued_time)
span.set_attribute(SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS, prompt_length)
span.set_attribute(
SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS,
metrics.num_generation_tokens,
)
span.set_attribute(
SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_PREFILL, prefill_time
)
span.set_attribute(
SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_DECODE, decode_time
)
span.set_attribute(
SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_INFERENCE, inference_time
)
# meta
span.set_attribute(
SpanAttributes.GEN_AI_REQUEST_ID, req_state.external_req_id
# Calculate timing metrics
e2e_time = iteration_stats.iteration_timestamp - metrics.arrival_time
queued_time = metrics.scheduled_ts - metrics.queued_ts
prefill_time = metrics.first_token_ts - metrics.scheduled_ts
decode_time = metrics.last_token_ts - metrics.first_token_ts
inference_time = metrics.last_token_ts - metrics.scheduled_ts
# Build attributes dict
attributes: dict[str, Any] = {
SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN: (
metrics.first_token_latency
),
SpanAttributes.GEN_AI_LATENCY_E2E: e2e_time,
SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE: queued_time,
SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS: prompt_length,
SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS: (
metrics.num_generation_tokens
),
SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_PREFILL: prefill_time,
SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_DECODE: decode_time,
SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_INFERENCE: inference_time,
SpanAttributes.GEN_AI_REQUEST_ID: req_state.external_req_id,
}
# Add optional request parameters
if req_state.top_p:
attributes[SpanAttributes.GEN_AI_REQUEST_TOP_P] = req_state.top_p
if req_state.max_tokens_param:
attributes[SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS] = (
req_state.max_tokens_param
)
if req_state.top_p:
span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TOP_P, req_state.top_p)
if req_state.max_tokens_param:
span.set_attribute(
SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS, req_state.max_tokens_param
)
if req_state.temperature:
span.set_attribute(
SpanAttributes.GEN_AI_REQUEST_TEMPERATURE, req_state.temperature
)
if req_state.n:
span.set_attribute(SpanAttributes.GEN_AI_REQUEST_N, req_state.n)
if req_state.temperature:
attributes[SpanAttributes.GEN_AI_REQUEST_TEMPERATURE] = (
req_state.temperature
)
if req_state.n:
attributes[SpanAttributes.GEN_AI_REQUEST_N] = req_state.n
instrument_manual(
span_name="llm_request",
start_time=arrival_time_ns,
attributes=attributes,
context=trace_context,
kind=SpanKind.SERVER,
)
def _update_stats_from_output(
self,
+2
View File
@@ -15,6 +15,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.tasks import SupportedTask
from vllm.tracing import instrument
from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.engine import ReconfigureDistributedRequest
@@ -84,6 +85,7 @@ class Executor(ABC):
)
return executor_class
@instrument(span_name="Executor init")
def __init__(
self,
vllm_config: VllmConfig,
+11
View File
@@ -41,6 +41,7 @@ from vllm.distributed.parallel_state import (
)
from vllm.envs import enable_envs_cache
from vllm.logger import init_logger
from vllm.tracing import instrument, maybe_init_worker_tracer
from vllm.utils.network_utils import (
get_distributed_init_method,
get_loopback_ip,
@@ -527,6 +528,7 @@ class WorkerProc:
)
)
@instrument(span_name="Worker init")
def __init__(
self,
vllm_config: VllmConfig,
@@ -740,6 +742,15 @@ class WorkerProc:
try:
reader.close()
# Initialize tracer
rank = kwargs.get("rank", 0)
maybe_init_worker_tracer(
instrumenting_module_name="vllm.worker",
process_kind="worker",
process_name=f"Worker_{rank}",
)
worker = WorkerProc(*args, **kwargs)
assert worker.worker_response_mq is not None
+3
View File
@@ -9,6 +9,7 @@ import torch.nn as nn
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.tracing import instrument
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
@@ -51,6 +52,7 @@ class CPUModelRunner(GPUModelRunner):
if isinstance(v, CpuGpuBuffer):
v.gpu = v.cpu
@instrument(span_name="Loading (CPU)")
def load_model(self, eep_scale_up: bool = False) -> None:
logger.info("Starting to load model %s...", self.model_config.model)
self.model = get_model(vllm_config=self.vllm_config)
@@ -61,6 +63,7 @@ class CPUModelRunner(GPUModelRunner):
def get_model(self) -> nn.Module:
return self.model
@instrument(span_name="Warmup (CPU)")
def warming_up_model(self) -> None:
logger.info("Warming up model for the compilation...")
# Only generate graph for the generic shape
+3
View File
@@ -93,6 +93,7 @@ from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors
from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
from vllm.tracing import instrument
from vllm.utils import length_from_prompt_token_ids_or_embeds
from vllm.utils.jsontree import json_map_leaves
from vllm.utils.math_utils import cdiv, round_up
@@ -4111,6 +4112,7 @@ class GPUModelRunner(
new_config = update_config(config, config_overrides)
setattr(self, config_name, new_config)
@instrument(span_name="Loading (GPU)")
def load_model(self, eep_scale_up: bool = False) -> None:
"""
Args:
@@ -5165,6 +5167,7 @@ class GPUModelRunner(
self.encoder_cache.clear()
gc.collect()
@instrument(span_name="Capture model")
def capture_model(self) -> int:
if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE:
logger.warning(
+4
View File
@@ -42,6 +42,7 @@ from vllm.platforms import current_platform
from vllm.profiler.wrapper import CudaProfilerWrapper, TorchProfilerWrapper
from vllm.sequence import IntermediateTensors
from vllm.tasks import SupportedTask
from vllm.tracing import instrument
from vllm.utils.mem_utils import MemorySnapshot, format_gib, memory_profiling
from vllm.utils.torch_utils import set_random_seed
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
@@ -186,6 +187,7 @@ class Worker(WorkerBase):
self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks
@instrument(span_name="Init device")
def init_device(self):
if self.device_config.device_type == "cuda":
# This env var set by Ray causes exceptions with graph building.
@@ -407,6 +409,7 @@ class Worker(WorkerBase):
self.model_runner.update_max_model_len(max_model_len)
logger.debug("Updated max_model_len to %d", max_model_len)
@instrument(span_name="Allocate KV cache")
def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
"""Allocate GPU KV cache with the specified kv_cache_config."""
@@ -426,6 +429,7 @@ class Worker(WorkerBase):
else:
self.model_runner.initialize_kv_cache(kv_cache_config)
@instrument(span_name="Warmup (GPU)")
def compile_or_warm_up_model(self) -> None:
warmup_sizes = []
+2
View File
@@ -11,6 +11,7 @@ from vllm.config import VllmConfig, set_current_vllm_config
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.tracing import instrument
from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.utils.system_utils import update_environment_variables
from vllm.v1.kv_cache_interface import KVCacheSpec
@@ -222,6 +223,7 @@ class WorkerWrapperBase:
envs = envs_list[self.rpc_rank]
update_environment_variables(envs)
@instrument(span_name="Worker init")
def init_worker(self, all_kwargs: list[dict[str, Any]]) -> None:
"""
Here we inject some common logic before initializing the worker.