From f9c4bdf6cfed1be6c91fc7a845e32dbc4c3ba66a Mon Sep 17 00:00:00 2001
From: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com>
Date: Wed, 4 Feb 2026 04:46:11 +0800
Subject: [PATCH] [TRTLLM-8921][feat] implement gen-first disagg_service
 (#11020)

Signed-off-by: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com>
---
 .../_torch/pyexecutor/kv_cache_transceiver.py |  18 ++-
 tensorrt_llm/_torch/pyexecutor/py_executor.py |  20 +++
 tensorrt_llm/disaggregated_params.py          |   8 +
 tensorrt_llm/executor/base_worker.py          |   6 +
 tensorrt_llm/executor/executor.py             |   3 +
 tensorrt_llm/llmapi/__init__.py               |   3 +-
 tensorrt_llm/llmapi/disagg_utils.py           |   2 +
 tensorrt_llm/llmapi/llm.py                    |  11 ++
 tensorrt_llm/serve/openai_disagg_server.py    |   2 -
 tensorrt_llm/serve/openai_disagg_service.py   |  81 ++++++++--
 tensorrt_llm/serve/openai_protocol.py         |  12 +-
 tensorrt_llm/serve/openai_server.py           |   6 +
 tensorrt_llm/serve/router.py                  |  23 ++-
 .../integration/test_lists/test-db/l0_a10.yml |   1 +
 .../api_stability/references/llm.yaml         |   3 +
 .../test_openai_disagg_service.py             | 153 ++++++++++++++++++
 16 files changed, 327 insertions(+), 25 deletions(-)
 create mode 100644 tests/unittest/disaggregated/test_openai_disagg_service.py

diff --git a/tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py b/tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py
index 8fe669456e..57c380eff6 100644
--- a/tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py
+++ b/tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py
@@ -1,6 +1,6 @@
 from abc import ABC, abstractmethod
 from os import getenv
-from typing import List
+from typing import Any, Dict, List
 
 import tensorrt_llm
 from tensorrt_llm import logger
@@ -97,7 +97,7 @@ class KvCacheTransceiver(ABC):
         raise NotImplementedError
 
     @abstractmethod
-    def prepare_context_request(self, requests: List[LlmRequest]):
+    def prepare_context_requests(self, requests: List[LlmRequest]):
         """
         Prepare the context request for the cache transceiver in generation-first mode.
         This method should set the context request state to DISAGG_CONTEXT_WAIT_SCHEDULER
@@ -107,10 +107,10 @@ class KvCacheTransceiver(ABC):
         ...
 
     @abstractmethod
-    def get_context_state(self):
+    def get_disaggregated_params(self) -> Dict[str, Any]:
         """
-        Return the opaque context request state, which will be attached to the generation request.
-        The generation server will use this state to get kvcache in generation-first mode.
+        Return a dictionary form of DisaggregatedParams to be set in the generation request.
+        The generation server will use it to fetch the KV cache in generation-first mode.
         """
         ...
@@ -160,11 +160,13 @@ class BindKvCacheTransceiver(KvCacheTransceiver):
     def cancel_request(self, req: LlmRequest):
         return self.impl.cancel_request(req)
 
-    def prepare_context_request(self, requests: List[LlmRequest]):
+    def prepare_context_requests(self, requests: List[LlmRequest]):
         raise NotImplementedError
 
-    def get_context_state(self):
-        raise NotImplementedError
+    def get_disaggregated_params(self):
+        # The C++ KV cache transceiver sets the disaggregated params on the context response;
+        # only the new Python cache transceiver will support gen-first disagg.
+        return {}
 
 
 class CacheTransBufferManager:
diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py
index 2c0593d651..e8be80483f 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -13,6 +13,7 @@ from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
 import torch
 
 from tensorrt_llm._torch.expert_statistic import ExpertStatistic
+from tensorrt_llm.llmapi import DisaggScheduleStyle
 from tensorrt_llm.serve.responses_utils import get_steady_clock_now_in_seconds
 
 try:
@@ -1416,6 +1417,7 @@ class PyExecutor:
             return None, None
 
         if self.kv_cache_transceiver:
+            self._check_disagg_ctx_schedulable_status(new_requests)
             self._check_disagg_gen_transfer_status()
             self._check_kv_transfer_timeout()
 
@@ -2394,6 +2396,24 @@ class PyExecutor:
 
         return
 
+    @nvtx_range("_check_disagg_ctx_schedulable_status")
+    def _check_disagg_ctx_schedulable_status(self,
+                                             new_requests: List[LlmRequest]):
+        """
+        In context-first mode, context requests are schedulable immediately; otherwise,
+        check whether they are ready to be scheduled by querying the KV cache transceiver.
+        """
+        if not self.kv_cache_transceiver:
+            return
+        ctx_only_requests = [
+            req for req in new_requests
+            if req.is_context_only_request and req.py_disaggregated_params.
+            schedule_style == DisaggScheduleStyle.GENERATION_FIRST
+        ]
+        if ctx_only_requests:
+            self.kv_cache_transceiver.prepare_context_requests(
+                ctx_only_requests)
+
     @nvtx_range("_pad_attention_dp_dummy_request")
     def _pad_attention_dp_dummy_request(self):
         """
diff --git a/tensorrt_llm/disaggregated_params.py b/tensorrt_llm/disaggregated_params.py
index a394002582..0dc006e1f6 100644
--- a/tensorrt_llm/disaggregated_params.py
+++ b/tensorrt_llm/disaggregated_params.py
@@ -1,4 +1,5 @@
 from dataclasses import dataclass
+from enum import IntEnum
 from typing import Any, Dict, List, Optional
 
 import numpy as np
@@ -11,6 +12,11 @@ import tensorrt as trt  # noqa
 from tensorrt_llm.bindings import executor as tllme
 
 
+class DisaggScheduleStyle(IntEnum):
+    CONTEXT_FIRST = 0
+    GENERATION_FIRST = 1
+
+
 @dataclass(slots=True, kw_only=True)
 class DisaggregatedParams:
     """Disaggregated serving parameters.
@@ -38,6 +44,8 @@ class DisaggregatedParams:
     disagg_request_id: Optional[int] = None
     ctx_dp_rank: Optional[int] = None
     ctx_info_endpoint: Optional[List[str]] = None
+    schedule_style: Optional[DisaggScheduleStyle] = None
+
     # E-P Disaggregated Params
     multimodal_embedding_handles: Optional[List[Dict[str, Any]]] = (
         None  # multimodal embedding handles should be a list of cudaIPC handles for each mm_embedding
diff --git a/tensorrt_llm/executor/base_worker.py b/tensorrt_llm/executor/base_worker.py
index d121f520c3..04888a2ad5 100644
--- a/tensorrt_llm/executor/base_worker.py
+++ b/tensorrt_llm/executor/base_worker.py
@@ -654,6 +654,12 @@ class BaseWorker(GenerationExecutor):
             self.engine.shutdown()
             self.engine = None
 
+    def get_disaggregated_params(self) -> dict:
+        if self.engine is None or self.engine.kv_cache_transceiver is None:
+            logger.warning("Engine or kv cache transceiver is not initialized")
+            return {}
+        return self.engine.kv_cache_transceiver.get_disaggregated_params()
+
     # Define a Callable to join iteration and request stats
     @staticmethod
     def _stats_serializer(
diff --git a/tensorrt_llm/executor/executor.py b/tensorrt_llm/executor/executor.py
index dc246f1114..12ffad5cd1 100644
--- a/tensorrt_llm/executor/executor.py
+++ b/tensorrt_llm/executor/executor.py
@@ -357,6 +357,9 @@ class GenerationExecutor(ABC):
             self._iter_kv_events_result.set_timeout(timeout)
         return self._iter_kv_events_result
 
+    def get_disaggregated_params(self) -> dict:
+        return {}
+
     @staticmethod
     def _create_ray_executor(
         worker_kwargs: Dict,
diff --git a/tensorrt_llm/llmapi/__init__.py b/tensorrt_llm/llmapi/__init__.py
index 9271f51f55..b87b21f9f5 100644
--- a/tensorrt_llm/llmapi/__init__.py
+++ b/tensorrt_llm/llmapi/__init__.py
@@ -1,5 +1,5 @@
 from .._torch.async_llm import AsyncLLM
-from ..disaggregated_params import DisaggregatedParams
+from ..disaggregated_params import DisaggregatedParams, DisaggScheduleStyle
 from ..executor import CompletionOutput, LoRARequest, RequestError
 from ..sampling_params import GuidedDecodingParams, SamplingParams
 from .build_cache import BuildCacheConfig
@@ -32,6 +32,7 @@ __all__ = [
     'GuidedDecodingParams',
     'SamplingParams',
     'DisaggregatedParams',
+    'DisaggScheduleStyle',
    'KvCacheConfig',
    'KvCacheRetentionConfig',
    'CudaGraphConfig',
diff --git a/tensorrt_llm/llmapi/disagg_utils.py b/tensorrt_llm/llmapi/disagg_utils.py
index d7f8fffe0d..ae6b7135bf 100644
--- a/tensorrt_llm/llmapi/disagg_utils.py
+++ b/tensorrt_llm/llmapi/disagg_utils.py
@@ -82,6 +82,8 @@ class DisaggServerConfig():
     node_id: int = uuid.getnode(
     ) % 1021  # Assuming only one disagg-server is running on a machine, moding mac by the largest 10-bit prime
     # If this causes collisions, users can set node_id manually within range [0, 1023] in config
+    schedule_style: Literal['context_first',
+                            'generation_first'] = 'context_first'
 
 
 @dataclass
diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py
index be37c89dc0..415751e7f2 100644
--- a/tensorrt_llm/llmapi/llm.py
+++ b/tensorrt_llm/llmapi/llm.py
@@ -104,6 +104,7 @@ TRT_LLM_DOCSTRING = TRT_LLMARGS_EXPLICIT_DOCSTRING + """
     tokenizer (tensorrt_llm.llmapi.tokenizer.TokenizerBase, optional): The tokenizer loaded by LLM instance, if any.
     workspace (pathlib.Path): The directory to store intermediate files.
     llm_id (str): The unique ID of the LLM instance.
+    disaggregated_params (dict): The disaggregated parameters of the LLM instance.
""" TORCH_LLM_DOCSTRING = TORCH_LLMARGS_EXPLICIT_DOCSTRING + """ @@ -111,6 +112,7 @@ TORCH_LLM_DOCSTRING = TORCH_LLMARGS_EXPLICIT_DOCSTRING + """ Attributes: tokenizer (tensorrt_llm.llmapi.tokenizer.TokenizerBase, optional): The tokenizer loaded by LLM instance, if any. llm_id (str): The unique ID of the LLM instance. + disaggregated_params (dict): The disaggregated parameters of the LLM instance. """ @@ -135,6 +137,7 @@ class BaseLLM: self._executor_cls = kwargs.pop("executor_cls", GenerationExecutor) self._orchestrator_type = kwargs.get("orchestrator_type", None) self._llm_id = None + self._disaggregated_params = {} log_level = logger.level logger.set_level("info") # force display the backend @@ -263,6 +266,14 @@ class BaseLLM: return self._llm_id + @property + @set_api_status("beta") + def disaggregated_params(self) -> dict: + if self._disaggregated_params is None: + self._disaggregated_params = self._executor.get_disaggregated_params( + ) if self._executor else {} + return self._disaggregated_params + def generate( self, inputs: Union[PromptInputs, Sequence[PromptInputs]], diff --git a/tensorrt_llm/serve/openai_disagg_server.py b/tensorrt_llm/serve/openai_disagg_server.py index 7639e405a5..06095da4c0 100644 --- a/tensorrt_llm/serve/openai_disagg_server.py +++ b/tensorrt_llm/serve/openai_disagg_server.py @@ -66,12 +66,10 @@ class RawRequestResponseHooks(ResponseHooks): def on_ctx_resp(self, ctx_server: str, response: UCompletionResponse): self.ctx_server = ctx_server - logger.debug(f"Received context response from {ctx_server} for request {response.choices[0].disaggregated_params.ctx_request_id}") def on_first_token(self, gen_server: str, request: UCompletionRequest, response: UCompletionResponse = None): self.gen_server = gen_server self.server_first_token_time = get_steady_clock_now_in_seconds() - logger.debug(f"Received first token from {gen_server} for request {request.disaggregated_params.ctx_request_id}") def on_resp_done(self, gen_server: str, request: UCompletionRequest, response: UCompletionResponse = None): if request.disaggregated_params: diff --git a/tensorrt_llm/serve/openai_disagg_service.py b/tensorrt_llm/serve/openai_disagg_service.py index 9f86974e50..4e322964ae 100644 --- a/tensorrt_llm/serve/openai_disagg_service.py +++ b/tensorrt_llm/serve/openai_disagg_service.py @@ -33,6 +33,7 @@ from tensorrt_llm.serve.openai_protocol import ( ChatCompletionRequest, CompletionRequest, DisaggregatedParams, + DisaggScheduleStyle, UCompletionRequest, UCompletionResponse, ) @@ -76,6 +77,15 @@ class OpenAIDisaggregatedService(OpenAIService): self._ctx_client = None self._gen_client = None self._disagg_cluster_manager = None + self._schedule_style = DisaggScheduleStyle.CONTEXT_FIRST + + match self._config.schedule_style: + case "generation_first": + self._send_disagg_request = self._send_disagg_request_gen_first + self._schedule_style = DisaggScheduleStyle.GENERATION_FIRST + case _: + self._send_disagg_request = self._send_disagg_request_ctx_first + self._schedule_style = DisaggScheduleStyle.CONTEXT_FIRST async def openai_completion( self, request: UCompletionRequest, hooks: Optional[ResponseHooks] = None @@ -102,7 +112,7 @@ class OpenAIDisaggregatedService(OpenAIService): raise RuntimeError("Cluster is not ready") return await self._send_disagg_request(request, hooks) - async def _send_disagg_request( + async def _send_disagg_request_ctx_first( self, request: UCompletionRequest, hooks: Optional[ResponseHooks] = None ) -> UCompletionResponseOrGenerator: if hooks: @@ -155,6 +165,7 @@ 
                 ),
                 "stream": False,
                 "stream_options": None,
+                "schedule_style": self._schedule_style,
             }
         )
         return ctx_request
@@ -162,16 +173,34 @@ class OpenAIDisaggregatedService(OpenAIService):
     def _get_gen_request(
         self,
         request: UCompletionRequest,
-        ctx_response: UCompletionResponse,
+        ctx_response: Optional[UCompletionResponse],
         disagg_request_id: Optional[int],
+        ctx_server_info: Optional[dict] = None,
     ) -> UCompletionRequest:
-        request.disaggregated_params = ctx_response.choices[0].disaggregated_params
-        request.disaggregated_params.request_type = "generation_only"
-        # Replace the string prompt with prompt_tokens_ids
-        if isinstance(request, CompletionRequest):
-            request.prompt = ctx_response.prompt_token_ids
-        elif isinstance(request, ChatCompletionRequest):
-            request.prompt_token_ids = ctx_response.prompt_token_ids
+        if ctx_response:
+            request.disaggregated_params = ctx_response.choices[0].disaggregated_params
+            request.disaggregated_params.request_type = "generation_only"
+            # Replace the string prompt with prompt_token_ids
+            if isinstance(request, CompletionRequest):
+                request.prompt = ctx_response.prompt_token_ids
+            elif isinstance(request, ChatCompletionRequest):
+                request.prompt_token_ids = ctx_response.prompt_token_ids
+        else:
+            # No ctx response: either a generation-only request or a generation-first disagg request.
+            request.disaggregated_params = DisaggregatedParams(
+                request_type="generation_only",
+                ctx_request_id=disagg_request_id,
+                disagg_request_id=disagg_request_id,
+                schedule_style=self._schedule_style,
+            )
+            if ctx_server_info and "server_info" in ctx_server_info:
+                disaggregated_params = ctx_server_info["server_info"].get("disaggregated_params", {})
+                if disaggregated_params:
+                    request.disaggregated_params = request.disaggregated_params.model_copy(
+                        update=disaggregated_params
+                    )
+
+        request.disaggregated_params.disagg_request_id = disagg_request_id
         return request
 
     async def _check_conditional_disagg(self, request: UCompletionRequest) -> bool:
@@ -322,3 +351,37 @@ class OpenAIDisaggregatedService(OpenAIService):
                 "Invalid disaggregated params in context phase response. disagg_request_id is None"
             )
         return ctx_response
+
+    async def _send_disagg_request_gen_first(
+        self, request: UCompletionRequest, hooks: Optional[ResponseHooks] = None
+    ) -> UCompletionResponseOrGenerator:
+        if hooks:
+            hooks.on_req_begin(request)
+        need_ctx = not (await self._check_gen_only_disagg(request))
+        # A None server means the HTTP client decides which server to use.
+        ctx_server, gen_server = None, None
+        ctx_server_info = None
+        tasks = []
+        ctx_req, gen_req = None, None
+        disagg_request_id = get_global_disagg_request_id(self._config.node_id)
+        if need_ctx:
+            ctx_server, ctx_server_info = await self._ctx_router.get_next_server(request)
+            ctx_req = self._get_ctx_request(request, disagg_request_id)
+            tasks.append(
+                asyncio.create_task(
+                    self._ctx_client.send_request(ctx_req, server=ctx_server, hooks=hooks)
+                )
+            )
+        gen_req = self._get_gen_request(
+            request,
+            ctx_response=None,
+            disagg_request_id=disagg_request_id,
+            ctx_server_info=ctx_server_info,
+        )
+        tasks.append(
+            asyncio.create_task(
+                self._gen_client.send_request(gen_req, server=gen_server, hooks=hooks)
+            )
+        )
+        responses = await asyncio.gather(*tasks)
+        return responses[-1]
diff --git a/tensorrt_llm/serve/openai_protocol.py b/tensorrt_llm/serve/openai_protocol.py
index 877688732a..440958250b 100644
--- a/tensorrt_llm/serve/openai_protocol.py
+++ b/tensorrt_llm/serve/openai_protocol.py
@@ -36,7 +36,8 @@ from typing_extensions import Annotated, Required, TypeAlias, TypedDict
 
 from tensorrt_llm.executor.request import LoRARequest
 from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams
-from tensorrt_llm.llmapi import GuidedDecodingParams, SamplingParams
+from tensorrt_llm.llmapi import (DisaggScheduleStyle, GuidedDecodingParams,
+                                 SamplingParams)
 from tensorrt_llm.llmapi.reasoning_parser import ReasoningParserFactory
 
 
@@ -121,6 +122,7 @@ class DisaggregatedParams(OpenAIBaseModel):
     disagg_request_id: Optional[int] = None
     ctx_dp_rank: Optional[int] = None
     ctx_info_endpoint: Optional[str] = None
+    schedule_style: Optional[DisaggScheduleStyle] = None
 
 
 class ErrorResponse(OpenAIBaseModel):
@@ -1095,7 +1097,9 @@ def to_disaggregated_params(
         draft_tokens=tllm_disagg_params.draft_tokens,
         disagg_request_id=tllm_disagg_params.disagg_request_id,
         ctx_dp_rank=tllm_disagg_params.ctx_dp_rank,
-        ctx_info_endpoint=tllm_disagg_params.ctx_info_endpoint)
+        ctx_info_endpoint=tllm_disagg_params.ctx_info_endpoint,
+        schedule_style=tllm_disagg_params.schedule_style,
+    )
 
 
 def to_llm_disaggregated_params(
@@ -1111,7 +1115,9 @@ def to_llm_disaggregated_params(
         draft_tokens=disaggregated_params.draft_tokens,
         disagg_request_id=disaggregated_params.disagg_request_id,
         ctx_dp_rank=disaggregated_params.ctx_dp_rank,
-        ctx_info_endpoint=disaggregated_params.ctx_info_endpoint)
+        ctx_info_endpoint=disaggregated_params.ctx_info_endpoint,
+        schedule_style=disaggregated_params.schedule_style,
+    )
 
 
 UCompletionRequest = Union[CompletionRequest, ChatCompletionRequest]
diff --git a/tensorrt_llm/serve/openai_server.py b/tensorrt_llm/serve/openai_server.py
index 9d3f60fcb0..53732ec852 100644
--- a/tensorrt_llm/serve/openai_server.py
+++ b/tensorrt_llm/serve/openai_server.py
@@ -290,6 +290,9 @@ class OpenAIServer:
         self.app.add_api_route("/update_weights",
                                self.update_weights,
                                methods=["POST"])
+        self.app.add_api_route("/server_info",
+                               self.get_server_info,
+                               methods=["GET"])
         if self.llm.args.return_perf_metrics:
             # register /prometheus/metrics
             self.mount_metrics()
@@ -1131,6 +1134,9 @@ class OpenAIServer:
         await self.llm.collective_rpc('update_weights',
                                       args=(request.weights,))
         return JSONResponse(content={"status": "success"})
 
+    async def get_server_info(self) -> JSONResponse:
+        return JSONResponse(content={"disaggregated_params": self.llm.disaggregated_params})
+
     async def __call__(self, host, port, sockets: list[socket.socket] | None = None):
         # Store the binding address for server registration
         self.binding_addr = f"http://{host}:{port}"
diff --git a/tensorrt_llm/serve/router.py b/tensorrt_llm/serve/router.py
index a3bc0783f4..04336403ea 100644
--- a/tensorrt_llm/serve/router.py
+++ b/tensorrt_llm/serve/router.py
@@ -156,6 +156,7 @@ class Router(ABC):
                  **kwargs):
         self._servers = servers or []
         self._metadata_server = metadata_server
+        self._server_info: dict[str, dict] = {}
         self._server_role = server_role
         self._lock = asyncio.Lock()
         self._monitor_task = None
@@ -176,10 +177,26 @@ class Router(ABC):
     def servers(self) -> List[str]:
         return self._servers
 
+    async def _fetch_server_info(self, server: str, timeout: float) -> dict:
+        session = aiohttp.ClientSession()
+        try:
+            async with session.get(f"http://{server}/server_info",
+                                   timeout=timeout) as response:
+                return await response.json()
+        except Exception as e:
+            logger.error(f"Error fetching server info for server {server}: {e}")
+        finally:
+            await session.close()
+        return {}
+
     async def _prepare_server(self, server: str):
         if self._server_preparation_func:
             await self._server_preparation_func(server)
 
+        self._server_info[server] = await self._fetch_server_info(
+            server, self._health_check_timeout)
+        logger.info(f"Server {server} is ready with info: {self._server_info[server]}")
+
     async def prepare_servers(self, servers: Optional[List[str]] = None):
         for server in servers or self._servers:
             await self._prepare_server(server)
@@ -207,6 +224,7 @@ class Router(ABC):
                 old_server for old_server in old_servers if old_server != server
             ]
             self._on_servers_updated(old_servers, self._servers)
+            self._server_info.pop(server, None)
             logger.debug(
                 f"Removed server {server}, current server list: {self._servers}")
 
@@ -471,7 +489,7 @@ class RoundRobinRouter(Router):
                 raise ValueError(
                     f"No available servers after excluding {exclude_server}"
                 )
-        return server, {}
+        return server, {"server_info": self._server_info.get(server, {})}
 
     async def finish_request(self, request: OpenAIRequest):
         pass
@@ -558,7 +576,7 @@ class LoadBalancingRouter(Router):
 
         self._req_routing_table[id(request)] = server
 
-        return server, {}
+        return server, {"server_info": self._server_info.get(server, {})}
 
     def _get_server_load(self, server):
         return self._server_state[server]._num_active_tokens if self._use_tokens \
@@ -675,6 +693,7 @@ class KvCacheAwareRouter(Router):
             "block_hashes": block_hashes,  # list[list[int]]
             "token_lists": token_lists,  # list[list[int]]
             "matches": matches,  # list[int]
+            "server_info": self._server_info.get(server, {}),
         }
 
     async def finish_request(self,
diff --git a/tests/integration/test_lists/test-db/l0_a10.yml b/tests/integration/test_lists/test-db/l0_a10.yml
index 4288f8ef18..34ad12632e 100644
--- a/tests/integration/test_lists/test-db/l0_a10.yml
+++ b/tests/integration/test_lists/test-db/l0_a10.yml
@@ -31,6 +31,7 @@ l0_a10:
       - unittest/others/test_tracing.py
       - unittest/disaggregated/test_disagg_openai_client.py
       - unittest/disaggregated/test_disagg_utils.py
+      - unittest/disaggregated/test_openai_disagg_service.py
       - unittest/disaggregated/test_router.py
       - unittest/disaggregated/test_remoteDictionary.py
       - unittest/disaggregated/test_agent_multi_backends.py
diff --git a/tests/unittest/api_stability/references/llm.yaml b/tests/unittest/api_stability/references/llm.yaml
index fab224c41c..3f7deb867f 100644
--- a/tests/unittest/api_stability/references/llm.yaml
+++ b/tests/unittest/api_stability/references/llm.yaml
@@ -311,3 +311,6 @@ properties:
   llm_id:
     annotation: str
     default: inspect._empty
+  disaggregated_params:
+    annotation: dict
+    default: None
diff --git a/tests/unittest/disaggregated/test_openai_disagg_service.py b/tests/unittest/disaggregated/test_openai_disagg_service.py
new file mode 100644
index 0000000000..b3c7582347
--- /dev/null
+++ b/tests/unittest/disaggregated/test_openai_disagg_service.py
@@ -0,0 +1,153 @@
+import asyncio
+from unittest.mock import AsyncMock
+
+import pytest
+
+from tensorrt_llm.llmapi.disagg_utils import DisaggServerConfig
+from tensorrt_llm.serve.openai_disagg_service import OpenAIDisaggregatedService
+from tensorrt_llm.serve.openai_protocol import (
+    CompletionRequest,
+    CompletionResponse,
+    CompletionResponseChoice,
+    DisaggregatedParams,
+    DisaggScheduleStyle,
+    UsageInfo,
+)
+from tensorrt_llm.serve.router import Router
+
+
+def _client_factory(*_args, **_kwargs):
+    return AsyncMock()
+
+
+def _make_service(schedule_style: str) -> OpenAIDisaggregatedService:
+    config = DisaggServerConfig(server_configs=[], schedule_style=schedule_style)
+    ctx_router = AsyncMock(spec=Router)
+    gen_router = AsyncMock(spec=Router)
+    return OpenAIDisaggregatedService(
+        config, ctx_router, gen_router, client_factory=_client_factory
+    )
+
+
+def _make_completion_response(
+    text: str,
+    finish_reason: str,
+    disagg_request_id: int = 42,
+    prompt_token_ids=None,
+    context_only=True,
+) -> CompletionResponse:
+    if prompt_token_ids is None:
+        prompt_token_ids = [1, 2, 3]
+    return CompletionResponse(
+        model="test-model",
+        usage=UsageInfo(prompt_tokens=1, completion_tokens=1),
+        prompt_token_ids=prompt_token_ids,
+        choices=[
+            CompletionResponseChoice(
+                index=0,
+                text=text,
+                finish_reason=finish_reason,
+                disaggregated_params=DisaggregatedParams(
+                    request_type="context_only" if context_only else "generation_only",
+                    disagg_request_id=disagg_request_id,
+                    ctx_request_id=disagg_request_id,
+                ),
+            )
+        ],
+    )
+
+
+async def _mock_streaming_response(chunks):
+    for chunk in chunks:
+        await asyncio.sleep(0)
+        yield chunk
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("stream", [False, True], ids=["non-streaming", "streaming"])
+@pytest.mark.parametrize("schedule_style", ["context_first", "generation_first"])
+async def test_send_disagg_request(monkeypatch, stream, schedule_style):
+    # Simulate different orderings of ctx and gen responses in gen-first mode.
+    ctx_gen_delay = [
+        (0.2, 0.1),
+        (0.2, 0.2),
+        (0.1, 0.2),
+    ]
+    monkeypatch.delenv("TRTLLM_DISAGG_BENCHMARK_GEN_ONLY", raising=False)
+    service = _make_service(schedule_style)
+    if schedule_style == "context_first":
+        assert service._send_disagg_request == service._send_disagg_request_ctx_first
+        assert service._schedule_style == DisaggScheduleStyle.CONTEXT_FIRST
+    else:
+        assert service._send_disagg_request == service._send_disagg_request_gen_first
+        assert service._schedule_style == DisaggScheduleStyle.GENERATION_FIRST
+    opaque_state = "opaque_state"
+    for i, (ctx_delay, gen_delay) in enumerate(ctx_gen_delay):
+        stream_chunks = [b"data: gen-0\n\n", b"data: gen-1\n\n"]
+        service._ctx_client = AsyncMock()
+        service._gen_client = AsyncMock()
+        ctx_server_info = {
+            "server_info": {"disaggregated_params": {"encoded_opaque_state": opaque_state}}
+        }
+        service._ctx_router.get_next_server = AsyncMock(return_value=("ctx:9000",
+                                                                      ctx_server_info))
+        service._gen_router.get_next_server = AsyncMock(
+            return_value=("gen:9001", {"server_info": {}})
+        )
+        resp_text = f"response-{i}"
+
+        async def _delayed_ctx_response(*_args, **_kwargs):
+            request = _args[0]
+            server, info = await service._ctx_router.get_next_server(request)
+            await asyncio.sleep(ctx_delay)
+            return _make_completion_response(
+                resp_text,
+                finish_reason="length",
+                disagg_request_id=request.disaggregated_params.disagg_request_id,
+                context_only=True,
+            )
+
+        async def _delayed_gen_response(*_args, **_kwargs):
+            request = _args[0]
+            server, info = await service._gen_router.get_next_server(request)
+            await asyncio.sleep(gen_delay)
+            if stream:
+                return _mock_streaming_response(stream_chunks)
+            return _make_completion_response(
+                resp_text,
+                finish_reason="stop",
+                disagg_request_id=request.disaggregated_params.disagg_request_id,
+                context_only=False,
+            )
+
+        service._ctx_client.send_request = AsyncMock(side_effect=_delayed_ctx_response)
+        service._gen_client.send_request = AsyncMock(side_effect=_delayed_gen_response)
+
+        request = CompletionRequest(model="test-model", prompt="hello", stream=stream)
+        result = await service._send_disagg_request(request)
+
+        ctx_req = service._ctx_client.send_request.call_args.args[0]
+        assert ctx_req.disaggregated_params.request_type == "context_only"
+
+        gen_req = service._gen_client.send_request.call_args.args[0]
+        assert gen_req.disaggregated_params.request_type == "generation_only"
+        if schedule_style == "generation_first":
+            assert gen_req.disaggregated_params.encoded_opaque_state == opaque_state
+        assert (
+            gen_req.disaggregated_params.ctx_request_id
+            == ctx_req.disaggregated_params.disagg_request_id
+        )
+
+        if stream:
+            assert hasattr(result, "__aiter__")
+            chunks = [chunk async for chunk in result]
+            assert chunks == stream_chunks
+        else:
+            assert result.model == "test-model"
+            assert result.usage.prompt_tokens == 1
+            assert len(result.choices) == 1
+            assert result.choices[0].text == resp_text
+            assert result.choices[0].finish_reason == "stop"
+            assert (
+                result.choices[0].disaggregated_params.disagg_request_id
+                == ctx_req.disaggregated_params.disagg_request_id
+            )
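
Note (illustrative sketch, not part of the patch): in generation-first mode the
disagg server dispatches the context and generation requests concurrently, so
the generation request must carry its disaggregated params before any context
response exists. The sketch below mirrors _get_gen_request() with
ctx_response=None, using the DisaggregatedParams model from
tensorrt_llm/serve/openai_protocol.py and the DisaggScheduleStyle enum added
above; the helper name make_gen_first_params is hypothetical.

    from tensorrt_llm.llmapi import DisaggScheduleStyle
    from tensorrt_llm.serve.openai_protocol import DisaggregatedParams

    def make_gen_first_params(disagg_request_id: int) -> DisaggregatedParams:
        # The generation request is tagged generation_only and reuses the
        # globally unique disagg_request_id as ctx_request_id, so the
        # generation server can match the KV cache that the context server
        # transfers once its request is scheduled.
        return DisaggregatedParams(
            request_type="generation_only",
            ctx_request_id=disagg_request_id,
            disagg_request_id=disagg_request_id,
            schedule_style=DisaggScheduleStyle.GENERATION_FIRST,
        )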