Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-02-16 07:53:55 +08:00
[TRTLLM-8921][feat] implement gen-first disagg_service (#11020)
Signed-off-by: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com>
This commit is contained in:
parent 8f90330239, commit f9c4bdf6cf
@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from os import getenv
from typing import List
from typing import Any, Dict, List

import tensorrt_llm
from tensorrt_llm import logger
@@ -97,7 +97,7 @@ class KvCacheTransceiver(ABC):
        raise NotImplementedError

    @abstractmethod
    def prepare_context_request(self, requests: List[LlmRequest]):
    def prepare_context_requests(self, requests: List[LlmRequest]):
        """
        Prepare the context request for the cache transceiver in generation-first mode.
        This method should set the context request state to DISAGG_CONTEXT_WAIT_SCHEDULER
@@ -107,10 +107,10 @@ class KvCacheTransceiver(ABC):
        ...

    @abstractmethod
    def get_context_state(self):
    def get_disaggregated_params(self) -> Dict[str, Any]:
        """
        Return the opaque context request state, which will be attached to the generation request.
        The generation server will use this state to get kvcache in generation-first mode.
        Return a dictionary form of DisaggregatedParams to be set in the generation request.
        The generation server will use it to get kvcache in generation-first mode.
        """
        ...

@@ -160,11 +160,13 @@ class BindKvCacheTransceiver(KvCacheTransceiver):
    def cancel_request(self, req: LlmRequest):
        return self.impl.cancel_request(req)

    def prepare_context_request(self, requests: List[LlmRequest]):
    def prepare_context_requests(self, requests: List[LlmRequest]):
        raise NotImplementedError

    def get_context_state(self):
        raise NotImplementedError
    def get_disaggregated_params(self):
        # Cpp kv cache transceiver will set the disaggregated params to context response
        # Only new py cache transceiver will support gen-first disagg
        return {}


class CacheTransBufferManager:
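For orientation, here is a minimal sketch (not part of this diff) of how a Python-based transceiver could satisfy the two new abstract methods. Only the method signatures come from the interface above; the LlmRequestState attribute path and the _endpoint attribute are assumptions.

class PyKvCacheTransceiver(KvCacheTransceiver):

    def prepare_context_requests(self, requests: List[LlmRequest]):
        # Gen-first mode: park context requests until the generation side has
        # scheduled them, per the DISAGG_CONTEXT_WAIT_SCHEDULER note above.
        for req in requests:
            req.state = LlmRequestState.DISAGG_CONTEXT_WAIT_SCHEDULER  # assumed attribute path

    def get_disaggregated_params(self) -> Dict[str, Any]:
        # Dictionary form of DisaggregatedParams that the server exposes via
        # /server_info and merges into generation requests.
        return {"ctx_info_endpoint": self._endpoint}  # _endpoint is hypothetical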
@@ -13,6 +13,7 @@ from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
import torch

from tensorrt_llm._torch.expert_statistic import ExpertStatistic
from tensorrt_llm.llmapi import DisaggScheduleStyle
from tensorrt_llm.serve.responses_utils import get_steady_clock_now_in_seconds

try:
@@ -1416,6 +1417,7 @@ class PyExecutor:
            return None, None

        if self.kv_cache_transceiver:
            self._check_disagg_ctx_schedulable_status(new_requests)
            self._check_disagg_gen_transfer_status()
            self._check_kv_transfer_timeout()

@@ -2394,6 +2396,24 @@ class PyExecutor:

        return

    @nvtx_range("_check_disagg_ctx_schedulable_status")
    def _check_disagg_ctx_schedulable_status(self,
                                             new_requests: List[LlmRequest]):
        """
        In context-first mode, context requests are schedulable immediately;
        otherwise, we query the kv cache transceiver to check whether context requests are ready to be scheduled.
        """
        if not self.kv_cache_transceiver:
            return
        ctx_only_requests = [
            req for req in new_requests
            if req.is_context_only_request and req.py_disaggregated_params.
            schedule_style == DisaggScheduleStyle.GENERATION_FIRST
        ]
        if ctx_only_requests:
            self.kv_cache_transceiver.prepare_context_requests(
                ctx_only_requests)

    @nvtx_range("_pad_attention_dp_dummy_request")
    def _pad_attention_dp_dummy_request(self):
        """
@@ -1,4 +1,5 @@
from dataclasses import dataclass
from enum import IntEnum
from typing import Any, Dict, List, Optional

import numpy as np
@@ -11,6 +12,11 @@ import tensorrt as trt  # noqa
from tensorrt_llm.bindings import executor as tllme


class DisaggScheduleStyle(IntEnum):
    CONTEXT_FIRST = 0
    GENERATION_FIRST = 1


@dataclass(slots=True, kw_only=True)
class DisaggregatedParams:
    """Disaggregated serving parameters.
@@ -38,6 +44,8 @@ class DisaggregatedParams:
    disagg_request_id: Optional[int] = None
    ctx_dp_rank: Optional[int] = None
    ctx_info_endpoint: Optional[List[str]] = None
    schedule_style: Optional[DisaggScheduleStyle] = None

    # E-P Disaggregated Params
    multimodal_embedding_handles: Optional[List[Dict[str, Any]]] = (
        None  # multimodal embedding handles should be a list of cudaIPC handles for each mm_embedding
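To see the new pieces together: a hedged construction example mirroring how _get_gen_request builds these params later in this diff; the id values are purely illustrative.

from tensorrt_llm.llmapi import DisaggregatedParams, DisaggScheduleStyle

# Generation-only request descriptor for gen-first scheduling; at runtime the
# disagg server derives disagg_request_id from its node_id.
params = DisaggregatedParams(
    request_type="generation_only",
    ctx_request_id=42,
    disagg_request_id=42,
    schedule_style=DisaggScheduleStyle.GENERATION_FIRST,
)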
@@ -654,6 +654,12 @@ class BaseWorker(GenerationExecutor):
            self.engine.shutdown()
            self.engine = None

    def get_disaggregated_params(self) -> dict:
        if self.engine is None or self.engine.kv_cache_transceiver is None:
            logger.warning("Engine or kv cache transceiver is not initialized")
            return {}
        return self.engine.kv_cache_transceiver.get_disaggregated_params()

    # Define a Callable to join iteration and request stats
    @staticmethod
    def _stats_serializer(

@@ -357,6 +357,9 @@ class GenerationExecutor(ABC):
        self._iter_kv_events_result.set_timeout(timeout)
        return self._iter_kv_events_result

    def get_disaggregated_params(self) -> dict:
        return {}

    @staticmethod
    def _create_ray_executor(
        worker_kwargs: Dict,
@@ -1,5 +1,5 @@
from .._torch.async_llm import AsyncLLM
from ..disaggregated_params import DisaggregatedParams
from ..disaggregated_params import DisaggregatedParams, DisaggScheduleStyle
from ..executor import CompletionOutput, LoRARequest, RequestError
from ..sampling_params import GuidedDecodingParams, SamplingParams
from .build_cache import BuildCacheConfig
@@ -32,6 +32,7 @@ __all__ = [
    'GuidedDecodingParams',
    'SamplingParams',
    'DisaggregatedParams',
    'DisaggScheduleStyle',
    'KvCacheConfig',
    'KvCacheRetentionConfig',
    'CudaGraphConfig',
@@ -82,6 +82,8 @@ class DisaggServerConfig():
    node_id: int = uuid.getnode(
    ) % 1021  # Assuming only one disagg-server is running on a machine, modding the MAC by the largest 10-bit prime
    # If this causes collisions, users can set node_id manually within range [0, 1023] in config
    schedule_style: Literal['context_first',
                            'generation_first'] = 'context_first'


@dataclass
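A hedged configuration example, mirroring the construction used by the new unit test at the bottom of this commit; the empty server_configs list is purely for illustration.

from tensorrt_llm.llmapi.disagg_utils import DisaggServerConfig

# Opt the disagg server into generation-first scheduling; node_id keeps its
# MAC-derived default unless collisions make an explicit value necessary.
config = DisaggServerConfig(server_configs=[], schedule_style="generation_first")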
@@ -104,6 +104,7 @@ TRT_LLM_DOCSTRING = TRT_LLMARGS_EXPLICIT_DOCSTRING + """
    tokenizer (tensorrt_llm.llmapi.tokenizer.TokenizerBase, optional): The tokenizer loaded by LLM instance, if any.
    workspace (pathlib.Path): The directory to store intermediate files.
    llm_id (str): The unique ID of the LLM instance.
    disaggregated_params (dict): The disaggregated parameters of the LLM instance.
"""

TORCH_LLM_DOCSTRING = TORCH_LLMARGS_EXPLICIT_DOCSTRING + """
@@ -111,6 +112,7 @@ TORCH_LLM_DOCSTRING = TORCH_LLMARGS_EXPLICIT_DOCSTRING + """
Attributes:
    tokenizer (tensorrt_llm.llmapi.tokenizer.TokenizerBase, optional): The tokenizer loaded by LLM instance, if any.
    llm_id (str): The unique ID of the LLM instance.
    disaggregated_params (dict): The disaggregated parameters of the LLM instance.
"""


@@ -135,6 +137,7 @@ class BaseLLM:
        self._executor_cls = kwargs.pop("executor_cls", GenerationExecutor)
        self._orchestrator_type = kwargs.get("orchestrator_type", None)
        self._llm_id = None
        self._disaggregated_params = {}

        log_level = logger.level
        logger.set_level("info")  # force display the backend
@@ -263,6 +266,14 @@ class BaseLLM:

        return self._llm_id

    @property
    @set_api_status("beta")
    def disaggregated_params(self) -> dict:
        if self._disaggregated_params is None:
            self._disaggregated_params = self._executor.get_disaggregated_params(
            ) if self._executor else {}
        return self._disaggregated_params

    def generate(
        self,
        inputs: Union[PromptInputs, Sequence[PromptInputs]],
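A minimal usage sketch of the new beta property; the model path is a placeholder, and the property returns an empty dict when no kv cache transceiver is initialized.

from tensorrt_llm import LLM

llm = LLM(model="/path/to/model")  # placeholder path
print(llm.disaggregated_params)    # dict form of DisaggregatedParams, or {} without a transceiver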
@@ -66,12 +66,10 @@ class RawRequestResponseHooks(ResponseHooks):

    def on_ctx_resp(self, ctx_server: str, response: UCompletionResponse):
        self.ctx_server = ctx_server
        logger.debug(f"Received context response from {ctx_server} for request {response.choices[0].disaggregated_params.ctx_request_id}")

    def on_first_token(self, gen_server: str, request: UCompletionRequest, response: UCompletionResponse = None):
        self.gen_server = gen_server
        self.server_first_token_time = get_steady_clock_now_in_seconds()
        logger.debug(f"Received first token from {gen_server} for request {request.disaggregated_params.ctx_request_id}")

    def on_resp_done(self, gen_server: str, request: UCompletionRequest, response: UCompletionResponse = None):
        if request.disaggregated_params:
@@ -33,6 +33,7 @@ from tensorrt_llm.serve.openai_protocol import (
    ChatCompletionRequest,
    CompletionRequest,
    DisaggregatedParams,
    DisaggScheduleStyle,
    UCompletionRequest,
    UCompletionResponse,
)
@@ -76,6 +77,15 @@ class OpenAIDisaggregatedService(OpenAIService):
        self._ctx_client = None
        self._gen_client = None
        self._disagg_cluster_manager = None
        self._schedule_style = DisaggScheduleStyle.CONTEXT_FIRST

        match self._config.schedule_style:
            case "generation_first":
                self._send_disagg_request = self._send_disagg_request_gen_first
                self._schedule_style = DisaggScheduleStyle.GENERATION_FIRST
            case _:
                self._send_disagg_request = self._send_disagg_request_ctx_first
                self._schedule_style = DisaggScheduleStyle.CONTEXT_FIRST

    async def openai_completion(
        self, request: UCompletionRequest, hooks: Optional[ResponseHooks] = None
@@ -102,7 +112,7 @@ class OpenAIDisaggregatedService(OpenAIService):
            raise RuntimeError("Cluster is not ready")
        return await self._send_disagg_request(request, hooks)

    async def _send_disagg_request(
    async def _send_disagg_request_ctx_first(
        self, request: UCompletionRequest, hooks: Optional[ResponseHooks] = None
    ) -> UCompletionResponseOrGenerator:
        if hooks:
@@ -155,6 +165,7 @@ class OpenAIDisaggregatedService(OpenAIService):
                ),
                "stream": False,
                "stream_options": None,
                "schedule_style": self._schedule_style,
            }
        )
        return ctx_request
@@ -162,16 +173,34 @@ class OpenAIDisaggregatedService(OpenAIService):
    def _get_gen_request(
        self,
        request: UCompletionRequest,
        ctx_response: UCompletionResponse,
        ctx_response: Optional[UCompletionResponse],
        disagg_request_id: Optional[int],
        ctx_server_info: Optional[dict] = None,
    ) -> UCompletionRequest:
        request.disaggregated_params = ctx_response.choices[0].disaggregated_params
        request.disaggregated_params.request_type = "generation_only"
        # Replace the string prompt with prompt_tokens_ids
        if isinstance(request, CompletionRequest):
            request.prompt = ctx_response.prompt_token_ids
        elif isinstance(request, ChatCompletionRequest):
            request.prompt_token_ids = ctx_response.prompt_token_ids
        if ctx_response:
            request.disaggregated_params = ctx_response.choices[0].disaggregated_params
            request.disaggregated_params.request_type = "generation_only"
            # Replace the string prompt with prompt_tokens_ids
            if isinstance(request, CompletionRequest):
                request.prompt = ctx_response.prompt_token_ids
            elif isinstance(request, ChatCompletionRequest):
                request.prompt_token_ids = ctx_response.prompt_token_ids
        else:
            # no ctx response, it's either a generation-only request or a generation-first disagg request
            request.disaggregated_params = DisaggregatedParams(
                request_type="generation_only",
                ctx_request_id=disagg_request_id,
                disagg_request_id=disagg_request_id,
                schedule_style=self._schedule_style,
            )
            if ctx_server_info and "server_info" in ctx_server_info:
                disaggregated_params = ctx_server_info["server_info"].get("disaggregated_params", {})
                if disaggregated_params:
                    request.disaggregated_params = request.disaggregated_params.model_copy(
                        update=disaggregated_params
                    )

        request.disaggregated_params.disagg_request_id = disagg_request_id
        return request

    async def _check_conditional_disagg(self, request: UCompletionRequest) -> bool:
@@ -322,3 +351,37 @@ class OpenAIDisaggregatedService(OpenAIService):
                "Invalid disaggregated params in context phase response. disagg_request_id is None"
            )
        return ctx_response

    async def _send_disagg_request_gen_first(
        self, request: UCompletionRequest, hooks: Optional[ResponseHooks] = None
    ) -> UCompletionResponse:
        if hooks:
            hooks.on_req_begin(request)
        # empty server means client decides which server to use
        need_ctx = not (await self._check_gen_only_disagg(request))
        ctx_server, gen_server = None, None
        ctx_server_info = None
        tasks = []
        ctx_req, gen_req = None, None
        disagg_request_id = get_global_disagg_request_id(self._config.node_id)
        if need_ctx:
            ctx_server, ctx_server_info = await self._ctx_router.get_next_server(request)
            ctx_req = self._get_ctx_request(request, disagg_request_id)
            tasks.append(
                asyncio.create_task(
                    self._ctx_client.send_request(ctx_req, server=ctx_server, hooks=hooks)
                )
            )
        gen_req = self._get_gen_request(
            request,
            ctx_response=None,
            disagg_request_id=disagg_request_id,
            ctx_server_info=ctx_server_info,
        )
        tasks.append(
            asyncio.create_task(
                self._gen_client.send_request(gen_req, server=gen_server, hooks=hooks)
            )
        )
        responses = await asyncio.gather(*tasks)
        return responses[-1]
@@ -36,7 +36,8 @@ from typing_extensions import Annotated, Required, TypeAlias, TypedDict

from tensorrt_llm.executor.request import LoRARequest
from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams
from tensorrt_llm.llmapi import GuidedDecodingParams, SamplingParams
from tensorrt_llm.llmapi import (DisaggScheduleStyle, GuidedDecodingParams,
                                 SamplingParams)
from tensorrt_llm.llmapi.reasoning_parser import ReasoningParserFactory


@@ -121,6 +122,7 @@ class DisaggregatedParams(OpenAIBaseModel):
    disagg_request_id: Optional[int] = None
    ctx_dp_rank: Optional[int] = None
    ctx_info_endpoint: Optional[str] = None
    schedule_style: Optional[DisaggScheduleStyle] = None


class ErrorResponse(OpenAIBaseModel):
@@ -1095,7 +1097,9 @@ def to_disaggregated_params(
        draft_tokens=tllm_disagg_params.draft_tokens,
        disagg_request_id=tllm_disagg_params.disagg_request_id,
        ctx_dp_rank=tllm_disagg_params.ctx_dp_rank,
        ctx_info_endpoint=tllm_disagg_params.ctx_info_endpoint)
        ctx_info_endpoint=tllm_disagg_params.ctx_info_endpoint,
        schedule_style=tllm_disagg_params.schedule_style,
    )


def to_llm_disaggregated_params(
@@ -1111,7 +1115,9 @@ def to_llm_disaggregated_params(
        draft_tokens=disaggregated_params.draft_tokens,
        disagg_request_id=disaggregated_params.disagg_request_id,
        ctx_dp_rank=disaggregated_params.ctx_dp_rank,
        ctx_info_endpoint=disaggregated_params.ctx_info_endpoint)
        ctx_info_endpoint=disaggregated_params.ctx_info_endpoint,
        schedule_style=disaggregated_params.schedule_style,
    )


UCompletionRequest = Union[CompletionRequest, ChatCompletionRequest]
@@ -290,6 +290,9 @@ class OpenAIServer:
        self.app.add_api_route("/update_weights",
                               self.update_weights,
                               methods=["POST"])
        self.app.add_api_route("/server_info",
                               self.get_server_info,
                               methods=["GET"])
        if self.llm.args.return_perf_metrics:
            # register /prometheus/metrics
            self.mount_metrics()
@@ -1131,6 +1134,9 @@ class OpenAIServer:
        await self.llm.collective_rpc('update_weights', args=(request.weights,))
        return JSONResponse(content={"status": "success"})

    async def get_server_info(self) -> JSONResponse:
        return JSONResponse(content={"disaggregated_params": self.llm.disaggregated_params})

    async def __call__(self, host, port, sockets: list[socket.socket] | None = None):
        # Store the binding address for server registration
        self.binding_addr = f"http://{host}:{port}"
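A hedged client-side sketch of the new endpoint; the address is a placeholder, and the payload shape follows get_server_info above (the router below fetches the same data with aiohttp).

import json
from urllib.request import urlopen

# Expected shape: {"disaggregated_params": {...}} (empty dict when the engine
# has no kv cache transceiver).
with urlopen("http://localhost:8000/server_info") as resp:  # placeholder address
    info = json.load(resp)
print(info["disaggregated_params"])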
@@ -156,6 +156,7 @@ class Router(ABC):
                 **kwargs):
        self._servers = servers or []
        self._metadata_server = metadata_server
        self._server_info: dict[str, dict] = {}
        self._server_role = server_role
        self._lock = asyncio.Lock()
        self._monitor_task = None
@@ -176,10 +177,26 @@ class Router(ABC):
    def servers(self) -> List[str]:
        return self._servers

    async def _fetch_server_info(self, server: str, timeout: float) -> dict:
        session = aiohttp.ClientSession()
        try:
            async with session.get(f"http://{server}/server_info",
                                   timeout=timeout) as response:
                return await response.json()
        except Exception as e:
            logger.error(f"Error fetching server info for server {server}: {e}")
        finally:
            await session.close()
        return {}

    async def _prepare_server(self, server: str):
        if self._server_preparation_func:
            await self._server_preparation_func(server)

        self._server_info[server] = await self._fetch_server_info(
            server, self._health_check_timeout)
        logger.info(f"server is ready with info: {self._server_info[server]}")

    async def prepare_servers(self, servers: Optional[List[str]] = None):
        for server in servers or self._servers:
            await self._prepare_server(server)
@@ -207,6 +224,7 @@ class Router(ABC):
            old_server for old_server in old_servers if old_server != server
        ]
        self._on_servers_updated(old_servers, self._servers)
        self._server_info.pop(server, None)
        logger.debug(
            f"Removed server {server}, current server list: {self._servers}")

@@ -471,7 +489,7 @@ class RoundRobinRouter(Router):
            raise ValueError(
                f"No available servers after excluding {exclude_server}"
            )
        return server, {}
        return server, {"server_info": self._server_info.get(server, {})}

    async def finish_request(self, request: OpenAIRequest):
        pass
@@ -558,7 +576,7 @@ class LoadBalancingRouter(Router):

        self._req_routing_table[id(request)] = server

        return server, {}
        return server, {"server_info": self._server_info.get(server, {})}

    def _get_server_load(self, server):
        return self._server_state[server]._num_active_tokens if self._use_tokens \
@@ -675,6 +693,7 @@ class KvCacheAwareRouter(Router):
            "block_hashes": block_hashes,  # list[list[int]]
            "token_lists": token_lists,  # list[list[int]]
            "matches": matches,  # list[int]
            "server_info": self._server_info.get(server, {}),
        }

    async def finish_request(self,
@@ -31,6 +31,7 @@ l0_a10:
- unittest/others/test_tracing.py
- unittest/disaggregated/test_disagg_openai_client.py
- unittest/disaggregated/test_disagg_utils.py
- unittest/disaggregated/test_openai_disagg_service.py
- unittest/disaggregated/test_router.py
- unittest/disaggregated/test_remoteDictionary.py
- unittest/disaggregated/test_agent_multi_backends.py
@@ -311,3 +311,6 @@ properties:
  llm_id:
    annotation: str
    default: inspect._empty
  disaggregated_params:
    annotation: dict
    default: None
tests/unittest/disaggregated/test_openai_disagg_service.py (new file, 153 lines)
@@ -0,0 +1,153 @@
import asyncio
from unittest.mock import AsyncMock

import pytest

from tensorrt_llm.llmapi.disagg_utils import DisaggServerConfig
from tensorrt_llm.serve.openai_disagg_service import OpenAIDisaggregatedService
from tensorrt_llm.serve.openai_protocol import (
    CompletionRequest,
    CompletionResponse,
    CompletionResponseChoice,
    DisaggregatedParams,
    DisaggScheduleStyle,
    UsageInfo,
)
from tensorrt_llm.serve.router import Router


def _client_factory(*_args, **_kwargs):
    return AsyncMock()


def _make_service(schedule_style: str) -> OpenAIDisaggregatedService:
    config = DisaggServerConfig(server_configs=[], schedule_style=schedule_style)
    ctx_router = AsyncMock(spec=Router)
    gen_router = AsyncMock(spec=Router)
    return OpenAIDisaggregatedService(
        config, ctx_router, gen_router, client_factory=_client_factory
    )


def _make_completion_response(
    text: str,
    finish_reason: str,
    disagg_request_id: int = 42,
    prompt_token_ids=None,
    context_only=True,
) -> CompletionResponse:
    if prompt_token_ids is None:
        prompt_token_ids = [1, 2, 3]
    return CompletionResponse(
        model="test-model",
        usage=UsageInfo(prompt_tokens=1, completion_tokens=1),
        prompt_token_ids=prompt_token_ids,
        choices=[
            CompletionResponseChoice(
                index=0,
                text=text,
                finish_reason=finish_reason,
                disaggregated_params=DisaggregatedParams(
                    request_type="context_only" if context_only else "generation_only",
                    disagg_request_id=disagg_request_id,
                    ctx_request_id=disagg_request_id,
                ),
            )
        ],
    )


async def _mock_streaming_response(chunks):
    for chunk in chunks:
        await asyncio.sleep(0)
        yield chunk


@pytest.mark.asyncio
@pytest.mark.parametrize("stream", [False, True], ids=["non-streaming", "streaming"])
@pytest.mark.parametrize("schedule_style", ["context_first", "generation_first"])
async def test_send_disagg_request(monkeypatch, stream, schedule_style):
    # simulate different order of ctx and gen responses in gen first mode
    ctx_gen_delay = [
        (0.2, 0.1),
        (0.2, 0.2),
        (0.1, 0.2),
    ]
    monkeypatch.delenv("TRTLLM_DISAGG_BENCHMARK_GEN_ONLY", raising=False)
    service = _make_service(schedule_style)
    if schedule_style == "context_first":
        assert service._send_disagg_request == service._send_disagg_request_ctx_first
        assert service._schedule_style == DisaggScheduleStyle.CONTEXT_FIRST
    else:
        assert service._send_disagg_request == service._send_disagg_request_gen_first
        assert service._schedule_style == DisaggScheduleStyle.GENERATION_FIRST
    opaque_state = "opaque_state"
    for i, (ctx_delay, gen_delay) in enumerate(ctx_gen_delay):
        stream_chunks = [b"data: gen-0\n\n", b"data: gen-1\n\n"]
        service._ctx_client = AsyncMock()
        service._gen_client = AsyncMock()
        ctx_server_info = {
            "server_info": {"disaggregated_params": {"encoded_opaque_state": opaque_state}}
        }
        service._ctx_router.get_next_server = AsyncMock(return_value=("ctx:9000", ctx_server_info))
        service._gen_router.get_next_server = AsyncMock(
            return_value=("gen:9001", {"server_info": {}})
        )
        resp_text = f"response-{i}"

        async def _delayed_ctx_response(*_args, **_kwargs):
            request = _args[0]
            server, info = await service._ctx_router.get_next_server(request)
            await asyncio.sleep(ctx_delay)
            return _make_completion_response(
                resp_text,
                finish_reason="length",
                disagg_request_id=request.disaggregated_params.disagg_request_id,
                context_only=True,
            )

        async def _delayed_gen_response(*_args, **_kwargs):
            request = _args[0]
            server, info = await service._gen_router.get_next_server(request)
            await asyncio.sleep(gen_delay)
            if stream:
                return _mock_streaming_response(stream_chunks)
            return _make_completion_response(
                resp_text,
                finish_reason="stop",
                disagg_request_id=request.disaggregated_params.disagg_request_id,
                context_only=False,
            )

        service._ctx_client.send_request = AsyncMock(side_effect=_delayed_ctx_response)
        service._gen_client.send_request = AsyncMock(side_effect=_delayed_gen_response)

        request = CompletionRequest(model="test-model", prompt="hello", stream=stream)
        result = await service._send_disagg_request(request)

        ctx_req = service._ctx_client.send_request.call_args.args[0]
        assert ctx_req.disaggregated_params.request_type == "context_only"

        gen_req = service._gen_client.send_request.call_args.args[0]
        assert gen_req.disaggregated_params.request_type == "generation_only"
        if schedule_style == "generation_first":
            assert gen_req.disaggregated_params.encoded_opaque_state == opaque_state
            assert (
                gen_req.disaggregated_params.ctx_request_id
                == ctx_req.disaggregated_params.disagg_request_id
            )

        if stream:
            assert hasattr(result, "__aiter__")
            chunks = [chunk async for chunk in result]
            assert chunks == stream_chunks
        else:
            assert result.model == "test-model"
            assert result.usage.prompt_tokens == 1
            assert len(result.choices) == 1
            assert result.choices[0].text == resp_text
            assert result.choices[0].finish_reason == "stop"
            assert (
                result.choices[0].disaggregated_params.disagg_request_id
                == ctx_req.disaggregated_params.disagg_request_id
            )