# Copyright (c) 2025, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

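"""Disaggregated serving frontend for the OpenAI-compatible API.

Defines OpenAIDisaggregatedService, which splits each request into a context
(prefill) phase and a generation (decode) phase served by separate workers.
"""
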
import asyncio
import copy
import os
from typing import Any, Callable, Dict, Optional, Tuple

from tensorrt_llm.llmapi.disagg_utils import (
    ConditionalDisaggConfig,
    DisaggClusterConfig,
    DisaggServerConfig,
    MetadataServerConfig,
    ServerRole,
)
from tensorrt_llm.logger import logger
from tensorrt_llm.serve.cluster_storage import ClusterStorage, WatchEventType
from tensorrt_llm.serve.disagg_auto_scaling import DisaggClusterManager, WorkerInfo
from tensorrt_llm.serve.metadata_server import JsonDictionary
from tensorrt_llm.serve.openai_client import OpenAIClient
from tensorrt_llm.serve.openai_protocol import (
    ChatCompletionRequest,
    CompletionRequest,
    DisaggregatedParams,
    UCompletionRequest,
    UCompletionResponse,
)
from tensorrt_llm.serve.openai_service import OpenAIService
from tensorrt_llm.serve.perf_metrics import DisaggPerfMetricsCollector
from tensorrt_llm.serve.responses_utils import (
    ResponseHooks,
    UCompletionResponseOrGenerator,
    done_generator,
)
from tensorrt_llm.serve.router import KvCacheAwareRouter, Router


class OpenAIDisaggregatedService(OpenAIService):
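    """OpenAI-compatible service that orchestrates disaggregated LLM serving.

    A request is first (optionally) sent to a context server for the prefill
    phase; the disaggregated params returned by that phase are then forwarded
    to a generation server, which produces the final response.

    Illustrative wiring (the config, routers, and client factory come from
    the surrounding serving layer):

        service = OpenAIDisaggregatedService(config, ctx_router, gen_router,
                                             client_factory)
        await service.setup()
    """
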
    def __init__(
        self,
        config: DisaggServerConfig,
        ctx_router: Router,
        gen_router: Router,
        client_factory: Callable[[Router, ServerRole], OpenAIClient],
        metadata_server: Optional[JsonDictionary] = None,
        metadata_config: Optional[MetadataServerConfig] = None,
        req_timeout_secs: int = 180,
        server_start_timeout_secs: int = 180,
        perf_metrics_collector: Optional[DisaggPerfMetricsCollector] = None,
        disagg_cluster_storage: Optional[ClusterStorage] = None,
        health_check_interval_secs: int = 3,
    ):
        self._config = config
        self._ctx_router = ctx_router
        self._gen_router = gen_router
        self._client_factory = client_factory
        self._metadata_server = metadata_server
        self._metadata_config = metadata_config
        self._req_timeout_secs = req_timeout_secs
        self._server_start_timeout_secs = server_start_timeout_secs
        self._perf_metrics_collector = perf_metrics_collector
        self._cluster_storage = disagg_cluster_storage
        self._health_check_interval_secs = health_check_interval_secs

        # The clients and the cluster manager are created lazily in setup().
        self._ctx_client = None
        self._gen_client = None
        self._disagg_cluster_manager = None

    async def openai_completion(
        self, request: UCompletionRequest, hooks: Optional[ResponseHooks] = None
    ) -> UCompletionResponseOrGenerator:
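        """Handle a completion request via the disaggregated flow.

        Only a single string prompt or a single list of token ids is
        supported; a singleton prompt list is unwrapped to its sole element
        first.
        """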
        if not await self.is_ready():
            raise RuntimeError("Cluster is not ready")
        if not isinstance(request.prompt, str):
            # Unwrap a singleton list, e.g. ["prompt text"] or [[1, 2, 3]].
            if type(request.prompt) is list and len(request.prompt) == 1:
                request.prompt = request.prompt[0]
            elif not isinstance(request.prompt, list) or not all(
                isinstance(x, int) for x in request.prompt
            ):
                raise ValueError(
                    "Disaggregated server currently only supports single string prompt or list of integers in request"
                )

        return await self._send_disagg_request(request, hooks)

    async def openai_chat_completion(
        self, request: UCompletionRequest, hooks: Optional[ResponseHooks] = None
    ) -> UCompletionResponseOrGenerator:
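        """Handle a chat completion request via the disaggregated flow."""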
        if not await self.is_ready():
            raise RuntimeError("Cluster is not ready")
        return await self._send_disagg_request(request, hooks)

    async def _send_disagg_request(
        self, request: UCompletionRequest, hooks: Optional[ResponseHooks] = None
    ) -> UCompletionResponseOrGenerator:
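        """Run the two-phase (context, then generation) flow for a request.

        The context phase is skipped when conditional disaggregation finds a
        sufficient KV cache match or when gen-only benchmark mode is active;
        otherwise its disaggregated params are attached to the generation
        request before it is dispatched.
        """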
        if hooks:
            hooks.on_req_begin(request)
        # A server of None means the client decides which server to use.
        reserved_gen_server = None
        reserved_ctx_server = None
        # Reserve a gen server up front when conditional disagg may apply.
        reserved_gen_server, need_ctx = await self._check_conditional_disagg(request)
        need_ctx = need_ctx and not await self._check_gen_only_disagg(request)
        ctx_response = None
        gen_req = request
        if need_ctx:
            ctx_req = self._get_ctx_request(request)
            # The ctx request is non-streaming, so this returns a full
            # response rather than a generator.
            ctx_response = await self._ctx_client.send_request(
                ctx_req, server=reserved_ctx_server, hooks=hooks
            )
            await self._verify_ctx_response(ctx_response)
            gen_req = self._get_gen_request(request, ctx_response)
        if ctx_response is None or self._need_gen(ctx_response):
            return await self._gen_client.send_request(
                gen_req, server=reserved_gen_server, hooks=hooks
            )
        else:
            if request.stream:
                # The ctx client never returns a generator, even when
                # streaming was requested; make up for this by returning a
                # done generator.
                return done_generator()
            return ctx_response

    def _need_gen(self, response: UCompletionResponse) -> bool:
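        """Return True if the generation phase is still required.

        A context response that already finished for a reason other than
        "length" or "not_finished" is final; its disaggregated params are
        dropped so it can be returned to the client directly.
        """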
        if response and response.choices[0].finish_reason not in ["length", "not_finished"]:
            del response.choices[0].disaggregated_params
            return False
        return True

    def _get_ctx_request(self, request: UCompletionRequest) -> UCompletionRequest:
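        """Build a non-streaming, context-only copy of the request."""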
        ctx_request = copy.deepcopy(request)
        ctx_request.disaggregated_params = DisaggregatedParams(request_type="context_only")
        ctx_request.stream = False
        ctx_request.stream_options = None
        return ctx_request

    def _get_gen_request(
        self,
        request: UCompletionRequest,
        ctx_response: UCompletionResponse,
    ) -> UCompletionRequest:
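        """Convert the original request into a generation-only request.

        Attaches the disaggregated params from the context response and
        swaps the prompt for the prompt token ids it returned.
        """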
        request.disaggregated_params = ctx_response.choices[0].disaggregated_params
        request.disaggregated_params.request_type = "generation_only"
        # Replace the string prompt with the prompt_token_ids from the
        # context response.
        if isinstance(request, CompletionRequest):
            request.prompt = ctx_response.prompt_token_ids
        elif isinstance(request, ChatCompletionRequest):
            request.prompt_token_ids = ctx_response.prompt_token_ids
        return request

    async def _check_conditional_disagg(
        self, request: UCompletionRequest
    ) -> Tuple[Optional[str], bool]:
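        """Decide whether the context phase is needed for this request.

        Returns a (gen_server, need_ctx) pair. With conditional disagg
        enabled, the generation server with the best KV cache match is
        reserved, and the context phase is skipped when the unmatched part
        of the prompt fits within max_local_prefill_length.
        """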
        if self.conditional_disagg_config:
            assert isinstance(self._gen_router, KvCacheAwareRouter)
            # Query the KV cache state and pick the best generation server;
            # that server stays reserved for the generation phase.
            gen_server, info = await self._gen_router.get_next_server(request)
            match_length = sum(info["matches"])
            total_length = sum(len(token_list) for token_list in info["token_lists"])
            if (
                match_length == 0
                or total_length - match_length
                > self.conditional_disagg_config.max_local_prefill_length
            ):
                return gen_server, True
            return gen_server, False
        return None, True

    async def _check_gen_only_disagg(self, request: UCompletionRequest) -> bool:
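        """Return True when TRTLLM_DISAGG_BENCHMARK_GEN_ONLY=1 is set.

        In that benchmark mode the request bypasses the context phase and is
        given hard-coded disaggregated params instead.
        """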
        if os.getenv("TRTLLM_DISAGG_BENCHMARK_GEN_ONLY") == "1":
            # Hard-code the first token and ctx_request_id for testing.
            request.disaggregated_params = DisaggregatedParams(
                request_type="generation_only",
                first_gen_tokens=[7],
                ctx_request_id=1,
                encoded_opaque_state=None,
                draft_tokens=None,
            )
            request.ignore_eos = True
            return True
        return False

    async def cluster_info(self) -> Dict[str, Any]:
        cluster_info = {"is_ready": await self.is_ready()}
        if self._disagg_cluster_manager:
            cluster_info.update(await self._disagg_cluster_manager.cluster_info())
        return cluster_info

    async def is_ready(self) -> bool:
        if self._disagg_cluster_manager:
            return await self._disagg_cluster_manager.is_ready()
        return True

    @property
    def disagg_cluster_config(self) -> Optional[DisaggClusterConfig]:
        return self._config.disagg_cluster_config

    @property
    def conditional_disagg_config(self) -> Optional[ConditionalDisaggConfig]:
        return self._config.conditional_disagg_config

    async def setup(self) -> None:
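        """Create the context/generation clients and start worker discovery.

        Workers are discovered through the disagg cluster manager when
        cluster storage is configured; otherwise the optional metadata
        server is used and the method waits for all servers to become
        ready.
        """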
        self._ctx_client = self._client_factory(
            self._ctx_router, ServerRole.CONTEXT, self._config.max_retries
        )
        self._gen_client = self._client_factory(
            self._gen_router, ServerRole.GENERATION, self._config.max_retries
        )

        if self.disagg_cluster_config and self._cluster_storage:
            logger.info("Starting disagg cluster manager")
            self._disagg_cluster_manager = DisaggClusterManager(
                self.disagg_cluster_config, self._cluster_storage
            )
            await self._disagg_cluster_manager.start()
            await self._disagg_cluster_manager.watch_workers(on_event=self._on_worker_event)
            logger.info("Disagg cluster manager started")
        else:
            if self._metadata_server and self._metadata_config:
                logger.info("Starting server monitoring via metadata service")
                await self._ctx_router.start_server_monitoring(
                    self._metadata_config.refresh_interval
                )
                await self._gen_router.start_server_monitoring(
                    self._metadata_config.refresh_interval
                )
            await self._wait_for_all_servers_ready()

    async def teardown(self) -> None:
        await self._ctx_client.shutdown()
        await self._gen_client.shutdown()

        if self._disagg_cluster_manager:
            await self._disagg_cluster_manager.stop()

        if self._metadata_server:
            await self._ctx_router.stop_server_monitoring()
            await self._gen_router.stop_server_monitoring()

    async def _wait_for_all_servers_ready(self) -> None:
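        """Poll both clients until every server is ready or the startup timeout expires."""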
        # Skip context servers if TRTLLM_DISAGG_BENCHMARK_GEN_ONLY is set.
        gen_only = os.getenv("TRTLLM_DISAGG_BENCHMARK_GEN_ONLY") == "1"

        async def check_servers_ready():
            # Poll until every server reports ready; the enclosing
            # asyncio.wait_for() enforces the overall startup timeout, so a
            # silent fall-through after the deadline cannot happen.
            interval = self._health_check_interval_secs
            while True:
                if gen_only:
                    unready_ctx_servers = []
                else:
                    _, unready_ctx_servers = await self._ctx_client.check_ready()
                _, unready_gen_servers = await self._gen_client.check_ready()
                if len(unready_ctx_servers) == 0 and len(unready_gen_servers) == 0:
                    if gen_only:
                        logger.info("Generation servers are ready (context servers skipped)")
                    else:
                        logger.info("All servers are ready")
                    return
                logger.info(
                    f"Waiting for servers, context: {unready_ctx_servers}, generation: {unready_gen_servers}"
                )
                await asyncio.sleep(interval)

        try:
            await asyncio.wait_for(check_servers_ready(), timeout=self._server_start_timeout_secs)
        except asyncio.TimeoutError:
            raise TimeoutError("Timeout waiting for context and generation servers to be ready")

    async def _on_worker_event(self, worker_info: WorkerInfo, event_type: WatchEventType):
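        """Add or remove a worker from its role's router on a cluster storage event."""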
        router_map = {ServerRole.CONTEXT: self._ctx_router, ServerRole.GENERATION: self._gen_router}
        worker_addr = f"{worker_info.host}:{worker_info.port}"
        try:
            router = router_map[worker_info.role]
            if event_type == WatchEventType.SET:
                await router.add_server(worker_addr)
            elif event_type == WatchEventType.DELETE:
                await router.remove_server(worker_addr)
            logger.info(f"Worker {event_type.name} event: {worker_info.worker_id}, {worker_addr}")
        except KeyError:
            logger.error(
                f"Unknown worker role: {worker_info.role}, Worker {worker_info.worker_id} event: {event_type.name}"
            )

    async def _verify_ctx_response(self, ctx_response: UCompletionResponse) -> None:
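        """Validate that the context response carries usable disaggregated params."""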
        if ctx_response:
            if len(ctx_response.choices) != 1:
                raise ValueError(
                    f"Context server returned {len(ctx_response.choices)} choices, expecting 1."
                )
            if ctx_response.choices[0].disaggregated_params is None:
                raise ValueError("Context server did not return disaggregated params")
            if ctx_response.choices[0].disaggregated_params.ctx_request_id is None:
                raise ValueError("Invalid disaggregated params in context phase response.")