TensorRT-LLMs/tensorrt_llm/serve/openai_disagg_service.py
Balaram Buddharaju 6a6e41f802
[TRTLLM-9468][chore] Update disagg benchmarking scripts to support context parallelism (#9720)
Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com>
2025-12-12 22:29:41 -08:00

308 lines
13 KiB
Python

# Copyright (c) 2025, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import copy
import os
from typing import Any, Callable, Dict, Optional
from tensorrt_llm.llmapi.disagg_utils import (
ConditionalDisaggConfig,
DisaggClusterConfig,
DisaggServerConfig,
MetadataServerConfig,
ServerRole,
)
from tensorrt_llm.logger import logger
from tensorrt_llm.serve.cluster_storage import ClusterStorage, WatchEventType
from tensorrt_llm.serve.disagg_auto_scaling import DisaggClusterManager, WorkerInfo
from tensorrt_llm.serve.metadata_server import JsonDictionary
from tensorrt_llm.serve.openai_client import OpenAIClient
from tensorrt_llm.serve.openai_protocol import (
ChatCompletionRequest,
CompletionRequest,
DisaggregatedParams,
UCompletionRequest,
UCompletionResponse,
)
from tensorrt_llm.serve.openai_service import OpenAIService
from tensorrt_llm.serve.perf_metrics import DisaggPerfMetricsCollector
from tensorrt_llm.serve.responses_utils import (
ResponseHooks,
UCompletionResponseOrGenerator,
done_generator,
)
from tensorrt_llm.serve.router import KvCacheAwareRouter, Router
class OpenAIDisaggregatedService(OpenAIService):
    """OpenAI-compatible service that disaggregates prefill and decode.

    Each incoming completion/chat request is (optionally) first sent to a
    *context* (prefill) worker pool, whose response carries the
    ``disaggregated_params`` handshake (first tokens, ctx request id, opaque
    KV-transfer state).  The request is then forwarded, with those params
    attached, to a *generation* (decode) worker pool which streams or returns
    the final response.  Worker discovery is handled either by a
    ``DisaggClusterManager`` backed by cluster storage, or by polling a
    metadata server.
    """

    def __init__(
        self,
        config: DisaggServerConfig,
        ctx_router: Router,
        gen_router: Router,
        # NOTE: factory takes (router, role, max_retries) — see setup().
        client_factory: Callable[[Router, ServerRole, int], OpenAIClient],
        metadata_server: Optional[JsonDictionary] = None,
        metadata_config: Optional[MetadataServerConfig] = None,
        req_timeout_secs: int = 180,
        server_start_timeout_secs: int = 180,
        perf_metrics_collector: Optional[DisaggPerfMetricsCollector] = None,
        disagg_cluster_storage: Optional[ClusterStorage] = None,
        health_check_interval_secs: int = 3,
    ):
        self._config = config
        self._ctx_router = ctx_router
        self._gen_router = gen_router
        self._client_factory = client_factory
        self._metadata_server = metadata_server
        self._metadata_config = metadata_config
        self._req_timeout_secs = req_timeout_secs
        self._server_start_timeout_secs = server_start_timeout_secs
        self._perf_metrics_collector = perf_metrics_collector
        self._cluster_storage = disagg_cluster_storage
        self._health_check_interval_secs = health_check_interval_secs
        # Clients and the cluster manager are created lazily in setup().
        self._ctx_client: Optional[OpenAIClient] = None
        self._gen_client: Optional[OpenAIClient] = None
        self._disagg_cluster_manager: Optional[DisaggClusterManager] = None

    async def openai_completion(
        self, request: UCompletionRequest, hooks: Optional[ResponseHooks] = None
    ) -> UCompletionResponseOrGenerator:
        """Handle a /v1/completions request.

        Only a single string prompt or a single list of token ids is
        supported; a one-element list prompt is unwrapped in place.

        Raises:
            RuntimeError: if the cluster is not ready to serve.
            ValueError: for unsupported prompt shapes (e.g. batched prompts).
        """
        if not await self.is_ready():
            raise RuntimeError("Cluster is not ready")
        if not isinstance(request.prompt, str):
            # Unwrap a singleton list (either ["text"] or [[tok, ...]]).
            if type(request.prompt) is list and len(request.prompt) == 1:
                request.prompt = request.prompt[0]
            elif not isinstance(request.prompt, list) or not all(
                isinstance(x, int) for x in request.prompt
            ):
                raise ValueError(
                    "Disaggregated server currently only supports single string prompt or list of integers in request"
                )
        return await self._send_disagg_request(request, hooks)

    async def openai_chat_completion(
        self, request: UCompletionRequest, hooks: Optional[ResponseHooks] = None
    ) -> UCompletionResponseOrGenerator:
        """Handle a /v1/chat/completions request.

        Raises:
            RuntimeError: if the cluster is not ready to serve.
        """
        if not await self.is_ready():
            raise RuntimeError("Cluster is not ready")
        return await self._send_disagg_request(request, hooks)

    async def _send_disagg_request(
        self, request: UCompletionRequest, hooks: Optional[ResponseHooks] = None
    ) -> UCompletionResponseOrGenerator:
        """Run the two-phase (context then generation) request flow.

        Conditional disaggregation may reserve a generation server up front
        and skip the context phase entirely when enough of the prompt is
        already cached on that server.
        """
        if hooks:
            hooks.on_req_begin(request)
        # An empty (None) server means the client decides which server to use.
        reserved_ctx_server = None
        # Reserve a gen server if conditional disagg is enabled; need_ctx
        # tells us whether a separate context (prefill) phase is required.
        reserved_gen_server, need_ctx = await self._check_conditional_disagg(request)
        need_ctx = need_ctx and not await self._check_gen_only_disagg(request)
        ctx_response = None
        gen_req = request
        if need_ctx:
            ctx_req = self._get_ctx_request(request)
            # The ctx phase never streams, so this is always a full response.
            ctx_response = await self._ctx_client.send_request(
                ctx_req, server=reserved_ctx_server, hooks=hooks
            )
            await self._verify_ctx_response(ctx_response)
            gen_req = self._get_gen_request(request, ctx_response)
        if ctx_response is None or self._need_gen(ctx_response):
            return await self._gen_client.send_request(
                gen_req, server=reserved_gen_server, hooks=hooks
            )
        else:
            # Generation was finished during the context phase.
            if request.stream:
                # The ctx client never returns a generator when streaming is
                # requested; make up for this by returning a done generator.
                return done_generator()
            return ctx_response

    def _need_gen(self, response: UCompletionResponse) -> bool:
        """Return True if a generation phase is still required.

        If the context phase already finished the sequence (e.g. stop word
        hit during prefill), strip the disaggregated params so they are not
        leaked to the client, and skip generation.
        """
        if response and response.choices[0].finish_reason not in ["length", "not_finished"]:
            del response.choices[0].disaggregated_params
            return False
        return True

    def _get_ctx_request(self, request: UCompletionRequest) -> UCompletionRequest:
        """Build the context-only copy of *request* (never streamed)."""
        ctx_request = copy.deepcopy(request)
        ctx_request.disaggregated_params = DisaggregatedParams(request_type="context_only")
        ctx_request.stream = False
        ctx_request.stream_options = None
        return ctx_request

    def _get_gen_request(
        self,
        request: UCompletionRequest,
        ctx_response: UCompletionResponse,
    ) -> UCompletionRequest:
        """Mutate *request* in place into the generation-only request.

        Attaches the disaggregated params returned by the context phase and
        replaces the textual prompt with the tokenized prompt so the
        generation server does not re-tokenize.
        """
        request.disaggregated_params = ctx_response.choices[0].disaggregated_params
        request.disaggregated_params.request_type = "generation_only"
        # Replace the string prompt with prompt_tokens_ids
        if isinstance(request, CompletionRequest):
            request.prompt = ctx_response.prompt_token_ids
        elif isinstance(request, ChatCompletionRequest):
            request.prompt_token_ids = ctx_response.prompt_token_ids
        return request

    async def _check_conditional_disagg(
        self, request: UCompletionRequest
    ) -> tuple[Optional[Any], bool]:
        """Decide whether the context phase is needed for this request.

        Returns:
            ``(reserved_gen_server, need_ctx)``.  When conditional disagg is
            enabled, the generation server is chosen (and reserved) up front
            via the KV-cache-aware router; the context phase is skipped when
            the unmatched prefix is short enough to prefill locally.
            ``reserved_gen_server`` is None when conditional disagg is off.
        """
        if self.conditional_disagg_config:
            assert isinstance(self._gen_router, KvCacheAwareRouter)
            # Query kv cache status and select a best gen_server.
            # The server is reserved for generation request
            gen_server, info = await self._gen_router.get_next_server(request)
            match_length = sum(info["matches"])
            total_length = sum(len(token_list) for token_list in info["token_lists"])
            if (
                match_length == 0
                or total_length - match_length
                > self.conditional_disagg_config.max_local_prefill_length
            ):
                return gen_server, True
            return gen_server, False
        return None, True

    async def _check_gen_only_disagg(self, request: UCompletionRequest) -> bool:
        """Return True when the gen-only benchmark mode is forced via env var.

        In that mode the request is rewritten in place with hard-coded
        disaggregated params so it can go straight to a generation server.
        """
        if os.getenv("TRTLLM_DISAGG_BENCHMARK_GEN_ONLY") == "1":
            # Hard-code first token, ctx_request_id for testing
            request.disaggregated_params = DisaggregatedParams(
                request_type="generation_only",
                first_gen_tokens=[7],
                ctx_request_id=1,
                encoded_opaque_state=None,
                draft_tokens=None,
            )
            request.ignore_eos = True
            return True
        return False

    async def cluster_info(self) -> Dict[str, Any]:
        """Return readiness plus (if available) cluster-manager details."""
        cluster_info = {"is_ready": await self.is_ready()}
        if self._disagg_cluster_manager:
            cluster_info.update(await self._disagg_cluster_manager.cluster_info())
        return cluster_info

    async def is_ready(self) -> bool:
        """Ready when no cluster manager is configured, else ask the manager."""
        if self._disagg_cluster_manager:
            return await self._disagg_cluster_manager.is_ready()
        return True

    @property
    def disagg_cluster_config(self) -> Optional[DisaggClusterConfig]:
        return self._config.disagg_cluster_config

    @property
    def conditional_disagg_config(self) -> Optional[ConditionalDisaggConfig]:
        return self._config.conditional_disagg_config

    async def setup(self) -> None:
        """Create clients, start worker discovery, and wait for readiness."""
        self._ctx_client = self._client_factory(
            self._ctx_router, ServerRole.CONTEXT, self._config.max_retries
        )
        self._gen_client = self._client_factory(
            self._gen_router, ServerRole.GENERATION, self._config.max_retries
        )
        if self.disagg_cluster_config and self._cluster_storage:
            # Auto-scaling mode: workers register through cluster storage.
            logger.info("Starting disagg cluster manager")
            self._disagg_cluster_manager = DisaggClusterManager(
                self.disagg_cluster_config, self._cluster_storage
            )
            await self._disagg_cluster_manager.start()
            await self._disagg_cluster_manager.watch_workers(on_event=self._on_worker_event)
            logger.info("Disagg cluster manager started")
        else:
            if self._metadata_server and self._metadata_config:
                # Static mode: routers poll the metadata service for workers.
                logger.info("Starting server monitoring via metadata service")
                await self._ctx_router.start_server_monitoring(
                    self._metadata_config.refresh_interval
                )
                await self._gen_router.start_server_monitoring(
                    self._metadata_config.refresh_interval
                )
            await self._wait_for_all_servers_ready()

    async def teardown(self) -> None:
        """Shut down clients, the cluster manager, and server monitoring."""
        await self._ctx_client.shutdown()
        await self._gen_client.shutdown()
        if self._disagg_cluster_manager:
            await self._disagg_cluster_manager.stop()
        if self._metadata_server:
            await self._ctx_router.stop_server_monitoring()
            await self._gen_router.stop_server_monitoring()

    async def _wait_for_all_servers_ready(self) -> None:
        """Poll both worker pools until every server reports ready.

        Raises:
            TimeoutError: if servers are not ready within
                ``self._server_start_timeout_secs``.
        """
        # Skip context servers if TRTLLM_DISAGG_BENCHMARK_GEN_ONLY is set
        gen_only = os.getenv("TRTLLM_DISAGG_BENCHMARK_GEN_ONLY") == "1"

        async def check_servers_ready():
            # Loop forever; asyncio.wait_for below owns the deadline.  (A
            # bounded loop here could return normally on timeout and make
            # this method report success with servers still unready.)
            interval = self._health_check_interval_secs
            while True:
                if gen_only:
                    unready_ctx_servers = []
                else:
                    _, unready_ctx_servers = await self._ctx_client.check_ready()
                _, unready_gen_servers = await self._gen_client.check_ready()
                if len(unready_ctx_servers) == 0 and len(unready_gen_servers) == 0:
                    if gen_only:
                        logger.info("Generation servers are ready (context servers skipped)")
                    else:
                        logger.info("All servers are ready")
                    return
                logger.info(
                    f"Waiting for servers, context: {unready_ctx_servers}, generation: {unready_gen_servers}"
                )
                await asyncio.sleep(interval)

        try:
            await asyncio.wait_for(check_servers_ready(), timeout=self._server_start_timeout_secs)
        except asyncio.TimeoutError:
            raise TimeoutError("Timeout waiting for context and generation servers to be ready")

    async def _on_worker_event(self, worker_info: WorkerInfo, event_type: WatchEventType):
        """React to a worker join/leave event from the cluster manager.

        SET adds the worker to the router for its role; DELETE removes it.
        Unknown roles are logged and ignored.
        """
        router_map = {ServerRole.CONTEXT: self._ctx_router, ServerRole.GENERATION: self._gen_router}
        worker_addr = f"{worker_info.host}:{worker_info.port}"
        try:
            router = router_map[worker_info.role]
            if event_type == WatchEventType.SET:
                await router.add_server(worker_addr)
            elif event_type == WatchEventType.DELETE:
                await router.remove_server(worker_addr)
            logger.info(f"Worker {event_type.name} event: {worker_info.worker_id}, {worker_addr}")
        except KeyError:
            logger.error(
                f"Unknown worker role: {worker_info.role}, Worker {worker_info.worker_id} event: {event_type.name}"
            )

    async def _verify_ctx_response(
        self, ctx_response: UCompletionResponse
    ) -> Optional[UCompletionResponse]:
        """Validate the context-phase response and pass it through.

        Raises:
            ValueError: on multiple choices, missing disaggregated params, or
                a missing ctx_request_id.
        """
        if ctx_response:
            if len(ctx_response.choices) != 1:
                raise ValueError(
                    f"Context server returned {len(ctx_response.choices)} choices, expecting 1."
                )
            if ctx_response.choices[0].disaggregated_params is None:
                raise ValueError("Context server did not return disaggregated params")
            if ctx_response.choices[0].disaggregated_params.ctx_request_id is None:
                raise ValueError("Invalid disaggregated params in context phase response.")
        return ctx_response