# Copyright (c) 2025, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#!/usr/bin/env python
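"""OpenAI-compatible disaggregated serving frontend.

Defines the FastAPI app that routes /v1/completions and /v1/chat/completions
requests across separate context and generation servers, exposes
health/metrics endpoints, and aligns the workers' steady clocks so that
cross-server perf-metric timestamps are comparable.
"""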

import asyncio
import signal
import traceback
from contextlib import asynccontextmanager
from typing import Callable, Optional

import aiohttp
import uvicorn
from fastapi import FastAPI, HTTPException, Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, Response, StreamingResponse

# yapf: disable
from tensorrt_llm.executor import CppExecutorError
from tensorrt_llm.llmapi import tracing
from tensorrt_llm.llmapi.disagg_utils import (DisaggServerConfig,
                                              MetadataServerConfig, ServerRole,
                                              get_ctx_gen_server_addrs)
from tensorrt_llm.logger import logger
from tensorrt_llm.serve.cluster_storage import (HttpClusterStorageServer,
                                                create_cluster_storage)
from tensorrt_llm.serve.metadata_server import create_metadata_server
from tensorrt_llm.serve.openai_client import OpenAIClient, OpenAIHttpClient
from tensorrt_llm.serve.openai_disagg_service import (
    OpenAIDisaggregatedService, ResponseHooks)
from tensorrt_llm.serve.openai_protocol import (UCompletionRequest,
                                                UCompletionResponse)
from tensorrt_llm.serve.perf_metrics import DisaggPerfMetricsCollector
from tensorrt_llm.serve.responses_utils import (ServerArrivalTimeMiddleware,
                                                get_steady_clock_now_in_seconds)
from tensorrt_llm.serve.router import Router, create_router
from tensorrt_llm.version import __version__ as VERSION
# yapf: enable

TIMEOUT_KEEP_ALIVE = 10  # seconds


class RawRequestResponseHooks(ResponseHooks):
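    """Per-request ResponseHooks that record which servers handled a request and when.

    The disaggregated service invokes these callbacks as a request flows through
    the context and generation servers; once the response completes, the
    collected routing and timing data is handed to the perf-metrics collector.
    """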

    def __init__(self, raw_req: Request, perf_metrics_collector: DisaggPerfMetricsCollector):
        self.raw_req = raw_req
        self.ctx_server = ""
        self.gen_server = ""
        self.server_first_token_time = 0
        self.perf_metrics_collector = perf_metrics_collector

    def on_req_begin(self, request: UCompletionRequest):
        ...

    def on_ctx_resp(self, ctx_server: str, response: UCompletionResponse):
        self.ctx_server = ctx_server
        logger.debug(f"Received context response from {ctx_server} for request {response.choices[0].disaggregated_params.ctx_request_id}")

    def on_first_token(self, gen_server: str, request: UCompletionRequest, response: Optional[UCompletionResponse] = None):
        self.gen_server = gen_server
        self.server_first_token_time = get_steady_clock_now_in_seconds()
        logger.debug(f"Received first token from {gen_server} for request {request.disaggregated_params.ctx_request_id}")

    def on_resp_done(self, gen_server: str, request: UCompletionRequest, response: Optional[UCompletionResponse] = None):
        if request.disaggregated_params:
            ctx_req_id = request.disaggregated_params.ctx_request_id
            # Fire-and-forget: record per-request metrics without blocking the response path.
            asyncio.create_task(self.perf_metrics_collector.add_per_request_metrics(self.ctx_server, gen_server, ctx_req_id, self.raw_req.state.server_arrival_time, self.server_first_token_time))


class OpenAIDisaggServer:
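    """OpenAI-compatible API server fronting disaggregated context/generation workers.

    Minimal usage sketch (illustrative; in practice this class is constructed by
    the `trtllm-serve disaggregated` entry point, and `config` comes from the
    parsed disaggregated-serving configuration):

        server = OpenAIDisaggServer(config)
        asyncio.run(server(host="0.0.0.0", port=8000))
    """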

    def __init__(self,
                 config: DisaggServerConfig,
                 req_timeout_secs: int = 180,
                 server_start_timeout_secs: int = 180,
                 metadata_server_cfg: Optional[MetadataServerConfig] = None,
                 metrics_interval_secs: int = 0):
        self._config = config
        self._req_timeout_secs = req_timeout_secs
        self._server_start_timeout_secs = server_start_timeout_secs
        self._metadata_server_cfg = metadata_server_cfg
        self._metrics_interval_secs = metrics_interval_secs

        self._ctx_servers, self._gen_servers = get_ctx_gen_server_addrs(config.server_configs)
        self._ctx_router = create_router(config.ctx_router_config, self._ctx_servers, metadata_server_cfg, create_metadata_server(metadata_server_cfg))
        self._gen_router = create_router(config.gen_router_config, self._gen_servers, metadata_server_cfg, create_metadata_server(metadata_server_cfg))
        self._metadata_server = create_metadata_server(metadata_server_cfg)
        self._perf_metrics_collector = DisaggPerfMetricsCollector(config.perf_metrics_max_requests)

        self._disagg_cluster_storage = create_cluster_storage(config.disagg_cluster_config.cluster_uri, config.disagg_cluster_config.cluster_name) if config.disagg_cluster_config else None

        self._service = OpenAIDisaggregatedService(
            self._config, self._ctx_router, self._gen_router, self._create_client,
            metadata_server=self._metadata_server,
            metadata_config=self._metadata_server_cfg,
            req_timeout_secs=self._req_timeout_secs,
            server_start_timeout_secs=self._server_start_timeout_secs,
            perf_metrics_collector=self._perf_metrics_collector,
            disagg_cluster_storage=self._disagg_cluster_storage)

        try:
            otlp_cfg = config.otlp_config
            if otlp_cfg and otlp_cfg.otlp_traces_endpoint:
                tracing.init_tracer("trt.llm", otlp_cfg.otlp_traces_endpoint)
                logger.info(
                    f"Initialized OTLP tracer successfully, endpoint: {otlp_cfg.otlp_traces_endpoint}"
                )
        except Exception as e:
            logger.error(f"Failed to initialize OTLP tracer: {e}")

        @asynccontextmanager
        async def lifespan(app: FastAPI):
            # Bring the service up before accepting traffic; tear it down on shutdown.
            await self._service.setup()
            await self._set_steady_clock_offsets()
            yield
            await self._service.teardown()

        self.app = FastAPI(lifespan=lifespan)

        self.app.add_middleware(ServerArrivalTimeMiddleware)

        @self.app.exception_handler(RequestValidationError)
        async def validation_exception_handler(_, exc):
            return JSONResponse(status_code=400, content={"error": str(exc)})

        self.register_routes()

    def _create_client(self, router: Router, role: ServerRole, max_retries: int = 1) -> OpenAIClient:
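        """Create an HTTP client for the given router/role and register it with the perf-metrics collector."""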
        client = OpenAIHttpClient(router, role, self._req_timeout_secs, max_retries)
        self._perf_metrics_collector.add_client(client)
        return client

    def register_routes(self):
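        """Register the OpenAI-compatible API routes plus health, version, and metrics endpoints."""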
        self.app.add_api_route("/v1/completions", self._wrap_entry_point(self._service.openai_completion), methods=["POST"])
        self.app.add_api_route("/v1/chat/completions", self._wrap_entry_point(self._service.openai_chat_completion), methods=["POST"])
        self.app.add_api_route("/health", self.health, methods=["GET"])
        self.app.add_api_route("/cluster_info", self.cluster_info, methods=["GET"])
        self.app.add_api_route("/version", self.version, methods=["GET"])
        self.app.add_api_route("/perf_metrics", self._perf_metrics_collector.get_perf_metrics, methods=["GET"])
        # Import prometheus_client lazily so that it is only loaded after
        # `set_prometheus_multiproc_dir` has configured the multiprocess mode.
        from prometheus_client import make_asgi_app
        self.app.mount("/prometheus/metrics", make_asgi_app())
        if self._disagg_cluster_storage and isinstance(self._disagg_cluster_storage, HttpClusterStorageServer):
            self._disagg_cluster_storage.add_routes(self.app)

    def _wrap_entry_point(self, entry_point: Callable) -> Callable:
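        """Adapt a service entry point into a FastAPI handler that attaches per-request response hooks."""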
        async def wrapper(req: UCompletionRequest, raw_req: Request) -> Response:
            try:
                hooks = RawRequestResponseHooks(raw_req, self._perf_metrics_collector)
                response_or_generator = await entry_point(req, hooks)
                if req.stream:
                    return StreamingResponse(content=response_or_generator, media_type="text/event-stream")
                else:
                    return JSONResponse(content=response_or_generator.model_dump())
            except Exception as e:
                self._handle_exception(e)

        return wrapper

    def _handle_exception(self, exception):
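        """Map an exception onto an HTTP error; fatal executor errors trigger a process shutdown signal."""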
        if isinstance(exception, CppExecutorError):
            logger.error(f"CppExecutorError: {traceback.format_exc()}")
            # The executor is in an unrecoverable state: ask the process to shut down.
            signal.raise_signal(signal.SIGINT)
        elif isinstance(exception, HTTPException):
            logger.error(f"HTTPException {exception.status_code} {exception.detail}: {traceback.format_exc()}")
            raise exception
        else:
            logger.error(f"Internal server error: {traceback.format_exc()}")
            raise HTTPException(status_code=500, detail=f"Internal server error {str(exception)}")

    async def health(self) -> Response:
        if not await self._service.is_ready():
            return Response(status_code=500)
        return Response(status_code=200)

    async def cluster_info(self) -> JSONResponse:
        return JSONResponse(content=await self._service.cluster_info())

    async def version(self) -> JSONResponse:
        return JSONResponse(content={"version": VERSION})

    async def __call__(self, host: str, port: int):
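        """Serve the FastAPI app with uvicorn until the process is stopped."""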
        config = uvicorn.Config(self.app,
                                host=host,
                                port=port,
                                log_level=logger.level,
                                timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
        await uvicorn.Server(config).serve()

    # TODO: rework this for service discovery; for now it only supports a static server list
    async def _set_steady_clock_offsets(self):
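        """Measure each worker's steady-clock offset (NTP-style) and push the correction back to it."""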
        STEADY_CLOCK_OFFSET_ENDPOINT = "/steady_clock_offset"

        async def query_steady_clock_offset(session: aiohttp.ClientSession, server_url: str) -> tuple[Optional[float], Optional[float]]:
            try:
                originate_ts = get_steady_clock_now_in_seconds()
                async with session.get(server_url + STEADY_CLOCK_OFFSET_ENDPOINT) as response:
                    destination_ts = get_steady_clock_now_in_seconds()
                    if response.status == 200:
                        response_content = await response.json()
                        # Compute the steady clock difference using the NTP clock synchronization algorithm:
                        # https://en.wikipedia.org/wiki/Network_Time_Protocol#Clock_synchronization_algorithm
                        receive_ts = response_content['receive_ts']
                        transmit_ts = response_content['transmit_ts']
                        # Round-trip delay, excluding the time the worker spent handling the request.
                        delay = (destination_ts - originate_ts) - (transmit_ts - receive_ts)
                        # Offset is the average of the clock differences observed in each direction.
                        offset = ((receive_ts - originate_ts) + (transmit_ts - destination_ts)) / 2
                        return delay, offset
                    else:
                        return None, None
            except Exception:
                return None, None

        async def set_steady_clock_offset(session: aiohttp.ClientSession, server_url: str, offset: float) -> None:
            payload = {"offset": offset}
            async with session.post(server_url + STEADY_CLOCK_OFFSET_ENDPOINT, json=payload) as response:
                if response.status != 200:
                    logger.warning(f"Cannot set steady clock offset for server {server_url}; perf metrics timestamps may be misaligned")

        async def align_steady_clock_offset(session: aiohttp.ClientSession, server_url: str) -> None:
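            """Measure one server's clock offset and push back the negated value as its correction."""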
            server_url = f"http://{server_url}" if not server_url.startswith("http://") else server_url
            delay, offset = await query_steady_clock_offset(session, server_url)
            if delay is None or offset is None:
                logger.warning(f"Unable to measure steady clock offset for {server_url}; skipping adjustment")
                return
            logger.info(f"Server: {server_url}, delay: {delay} seconds, offset: {offset} seconds")
            # Negate the offset so that worker servers can correct their steady clock by adding it.
            await set_steady_clock_offset(session, server_url, -offset)

        async with aiohttp.ClientSession(
                connector=aiohttp.TCPConnector(limit=0, limit_per_host=0, force_close=True),
                timeout=aiohttp.ClientTimeout(total=self._req_timeout_secs)) as session:
            await asyncio.gather(*[align_steady_clock_offset(session, server_url) for server_url in self._ctx_servers + self._gen_servers])