[TRTLLM-8921][feat] implement gen-first disagg_service (#11020)

Signed-off-by: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com>
Lizhi Zhou 2026-02-04 04:46:11 +08:00 committed by GitHub
parent 8f90330239
commit f9c4bdf6cf
16 changed files with 327 additions and 25 deletions

View File

@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from os import getenv
from typing import List
from typing import Any, Dict, List
import tensorrt_llm
from tensorrt_llm import logger
@ -97,7 +97,7 @@ class KvCacheTransceiver(ABC):
raise NotImplementedError
@abstractmethod
def prepare_context_request(self, requests: List[LlmRequest]):
def prepare_context_requests(self, requests: List[LlmRequest]):
"""
Prepare the context requests for the cache transceiver in generation-first mode.
This method should set each context request's state to DISAGG_CONTEXT_WAIT_SCHEDULER
@ -107,10 +107,10 @@ class KvCacheTransceiver(ABC):
...
@abstractmethod
def get_context_state(self):
def get_disaggregated_params(self) -> Dict[str, Any]:
"""
Return the opaque context request state, which will be attached to the generation request.
The generation server will use this state to get kvcache in generation-first mode.
Return a dictionary form of DisaggregatedParams to set on the generation request.
The generation server will use it to fetch the KV cache in generation-first mode.
"""
...
@ -160,11 +160,13 @@ class BindKvCacheTransceiver(KvCacheTransceiver):
def cancel_request(self, req: LlmRequest):
return self.impl.cancel_request(req)
def prepare_context_request(self, requests: List[LlmRequest]):
def prepare_context_requests(self, requests: List[LlmRequest]):
raise NotImplementedError
def get_context_state(self):
raise NotImplementedError
def get_disaggregated_params(self):
# The C++ KV cache transceiver sets the disaggregated params on the context response;
# only the new Python cache transceiver supports gen-first disagg.
return {}
class CacheTransBufferManager:
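For illustration, a minimal sketch (not part of this commit) of how a Python-side transceiver might implement the two new gen-first hooks; the class name, the endpoint field, and the exact dict keys are assumptions.

from typing import Any, Dict, List

class SketchPyKvCacheTransceiver:
    """Hypothetical gen-first transceiver mirroring the KvCacheTransceiver hooks above."""

    def __init__(self, endpoint: str):
        self._endpoint = endpoint       # assumed KV-transfer endpoint advertised to gen servers
        self._waiting: List[Any] = []   # context requests parked until their KV cache is pulled

    def prepare_context_requests(self, requests: List[Any]) -> None:
        # Park each context-only request; the executor keeps it waiting
        # (DISAGG_CONTEXT_WAIT_SCHEDULER) until a generation server asks for its KV cache.
        self._waiting.extend(requests)

    def get_disaggregated_params(self) -> Dict[str, Any]:
        # Dict form of DisaggregatedParams, surfaced to generation servers via /server_info.
        return {"ctx_info_endpoint": [self._endpoint]}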

View File

@ -13,6 +13,7 @@ from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
import torch
from tensorrt_llm._torch.expert_statistic import ExpertStatistic
from tensorrt_llm.llmapi import DisaggScheduleStyle
from tensorrt_llm.serve.responses_utils import get_steady_clock_now_in_seconds
try:
@ -1416,6 +1417,7 @@ class PyExecutor:
return None, None
if self.kv_cache_transceiver:
self._check_disagg_ctx_schedulable_status(new_requests)
self._check_disagg_gen_transfer_status()
self._check_kv_transfer_timeout()
@ -2394,6 +2396,24 @@ class PyExecutor:
return
@nvtx_range("_check_disagg_ctx_schedulable_status")
def _check_disagg_ctx_schedulable_status(self,
new_requests: List[LlmRequest]):
"""
In context-first mode, context requests are schedulable immediately;
otherwise, we need to check with the KV cache transceiver whether the
context requests are ready to be scheduled.
if not self.kv_cache_transceiver:
return
ctx_only_requests = [
req for req in new_requests
if req.is_context_only_request and req.py_disaggregated_params.
schedule_style == DisaggScheduleStyle.GENERATION_FIRST
]
if ctx_only_requests:
self.kv_cache_transceiver.prepare_context_requests(
ctx_only_requests)
@nvtx_range("_pad_attention_dp_dummy_request")
def _pad_attention_dp_dummy_request(self):
"""

View File

@ -1,4 +1,5 @@
from dataclasses import dataclass
from enum import IntEnum
from typing import Any, Dict, List, Optional
import numpy as np
@ -11,6 +12,11 @@ import tensorrt as trt # noqa
from tensorrt_llm.bindings import executor as tllme
class DisaggScheduleStyle(IntEnum):
CONTEXT_FIRST = 0
GENERATION_FIRST = 1
@dataclass(slots=True, kw_only=True)
class DisaggregatedParams:
"""Disaggregated serving parameters.
@ -38,6 +44,8 @@ class DisaggregatedParams:
disagg_request_id: Optional[int] = None
ctx_dp_rank: Optional[int] = None
ctx_info_endpoint: Optional[List[str]] = None
schedule_style: Optional[DisaggScheduleStyle] = None
# E-P Disaggregated Params
multimodal_embedding_handles: Optional[List[Dict[str, Any]]] = (
None # multimodal embedding handles should be a list of cudaIPC handles for each mm_embedding
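For reference, a generation request in gen-first mode might carry params like the sketch below; the id value is illustrative, request_type and ctx_request_id are pre-existing fields of this dataclass, and the construction mirrors how the disagg service builds the generation request later in this commit.

from tensorrt_llm.llmapi import DisaggregatedParams, DisaggScheduleStyle

# Gen-first: no context response exists yet, so only the shared disagg request id
# and the schedule style are filled in (42 is an illustrative id).
params = DisaggregatedParams(
    request_type="generation_only",
    ctx_request_id=42,
    disagg_request_id=42,
    schedule_style=DisaggScheduleStyle.GENERATION_FIRST,
)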

View File

@ -654,6 +654,12 @@ class BaseWorker(GenerationExecutor):
self.engine.shutdown()
self.engine = None
def get_disaggregated_params(self) -> dict:
if self.engine is None or self.engine.kv_cache_transceiver is None:
logger.warning("Engine or kv cache transceiver is not initialized")
return {}
return self.engine.kv_cache_transceiver.get_disaggregated_params()
# Define a Callable to join iteration and request stats
@staticmethod
def _stats_serializer(

View File

@ -357,6 +357,9 @@ class GenerationExecutor(ABC):
self._iter_kv_events_result.set_timeout(timeout)
return self._iter_kv_events_result
def get_disaggregated_params(self) -> dict:
return {}
@staticmethod
def _create_ray_executor(
worker_kwargs: Dict,

View File

@ -1,5 +1,5 @@
from .._torch.async_llm import AsyncLLM
from ..disaggregated_params import DisaggregatedParams
from ..disaggregated_params import DisaggregatedParams, DisaggScheduleStyle
from ..executor import CompletionOutput, LoRARequest, RequestError
from ..sampling_params import GuidedDecodingParams, SamplingParams
from .build_cache import BuildCacheConfig
@ -32,6 +32,7 @@ __all__ = [
'GuidedDecodingParams',
'SamplingParams',
'DisaggregatedParams',
'DisaggScheduleStyle',
'KvCacheConfig',
'KvCacheRetentionConfig',
'CudaGraphConfig',

View File

@ -82,6 +82,8 @@ class DisaggServerConfig():
node_id: int = uuid.getnode(
) % 1021 # Assuming only one disagg server runs per machine; take the MAC address modulo the largest 10-bit prime
# If this causes collisions, users can set node_id manually within range [0, 1023] in config
schedule_style: Literal['context_first',
'generation_first'] = 'context_first'
@dataclass
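Selecting the mode is a single config field. Mirroring the unit test added at the end of this commit, a gen-first server config can be constructed like this (empty server_configs purely for illustration):

from tensorrt_llm.llmapi.disagg_utils import DisaggServerConfig

config = DisaggServerConfig(server_configs=[], schedule_style="generation_first")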

View File

@ -104,6 +104,7 @@ TRT_LLM_DOCSTRING = TRT_LLMARGS_EXPLICIT_DOCSTRING + """
tokenizer (tensorrt_llm.llmapi.tokenizer.TokenizerBase, optional): The tokenizer loaded by LLM instance, if any.
workspace (pathlib.Path): The directory to store intermediate files.
llm_id (str): The unique ID of the LLM instance.
disaggregated_params (dict): The disaggregated parameters of the LLM instance.
"""
TORCH_LLM_DOCSTRING = TORCH_LLMARGS_EXPLICIT_DOCSTRING + """
@ -111,6 +112,7 @@ TORCH_LLM_DOCSTRING = TORCH_LLMARGS_EXPLICIT_DOCSTRING + """
Attributes:
tokenizer (tensorrt_llm.llmapi.tokenizer.TokenizerBase, optional): The tokenizer loaded by LLM instance, if any.
llm_id (str): The unique ID of the LLM instance.
disaggregated_params (dict): The disaggregated parameters of the LLM instance.
"""
@ -135,6 +137,7 @@ class BaseLLM:
self._executor_cls = kwargs.pop("executor_cls", GenerationExecutor)
self._orchestrator_type = kwargs.get("orchestrator_type", None)
self._llm_id = None
self._disaggregated_params = {}
log_level = logger.level
logger.set_level("info") # force display the backend
@ -263,6 +266,14 @@ class BaseLLM:
return self._llm_id
@property
@set_api_status("beta")
def disaggregated_params(self) -> dict:
if not self._disaggregated_params:
self._disaggregated_params = self._executor.get_disaggregated_params(
) if self._executor else {}
return self._disaggregated_params
def generate(
self,
inputs: Union[PromptInputs, Sequence[PromptInputs]],

View File

@ -66,12 +66,10 @@ class RawRequestResponseHooks(ResponseHooks):
def on_ctx_resp(self, ctx_server: str, response: UCompletionResponse):
self.ctx_server = ctx_server
logger.debug(f"Received context response from {ctx_server} for request {response.choices[0].disaggregated_params.ctx_request_id}")
def on_first_token(self, gen_server: str, request: UCompletionRequest, response: UCompletionResponse = None):
self.gen_server = gen_server
self.server_first_token_time = get_steady_clock_now_in_seconds()
logger.debug(f"Received first token from {gen_server} for request {request.disaggregated_params.ctx_request_id}")
def on_resp_done(self, gen_server: str, request: UCompletionRequest, response: UCompletionResponse = None):
if request.disaggregated_params:

View File

@ -33,6 +33,7 @@ from tensorrt_llm.serve.openai_protocol import (
ChatCompletionRequest,
CompletionRequest,
DisaggregatedParams,
DisaggScheduleStyle,
UCompletionRequest,
UCompletionResponse,
)
@ -76,6 +77,15 @@ class OpenAIDisaggregatedService(OpenAIService):
self._ctx_client = None
self._gen_client = None
self._disagg_cluster_manager = None
self._schedule_style = DisaggScheduleStyle.CONTEXT_FIRST
match self._config.schedule_style:
case "generation_first":
self._send_disagg_request = self._send_disagg_request_gen_first
self._schedule_style = DisaggScheduleStyle.GENERATION_FIRST
case _:
self._send_disagg_request = self._send_disagg_request_ctx_first
self._schedule_style = DisaggScheduleStyle.CONTEXT_FIRST
async def openai_completion(
self, request: UCompletionRequest, hooks: Optional[ResponseHooks] = None
@ -102,7 +112,7 @@ class OpenAIDisaggregatedService(OpenAIService):
raise RuntimeError("Cluster is not ready")
return await self._send_disagg_request(request, hooks)
async def _send_disagg_request(
async def _send_disagg_request_ctx_first(
self, request: UCompletionRequest, hooks: Optional[ResponseHooks] = None
) -> UCompletionResponseOrGenerator:
if hooks:
@ -155,6 +165,7 @@ class OpenAIDisaggregatedService(OpenAIService):
),
"stream": False,
"stream_options": None,
"schedule_style": self._schedule_style,
}
)
return ctx_request
@ -162,16 +173,34 @@ class OpenAIDisaggregatedService(OpenAIService):
def _get_gen_request(
self,
request: UCompletionRequest,
ctx_response: UCompletionResponse,
ctx_response: Optional[UCompletionResponse],
disagg_request_id: Optional[int],
ctx_server_info: Optional[dict] = None,
) -> UCompletionRequest:
request.disaggregated_params = ctx_response.choices[0].disaggregated_params
request.disaggregated_params.request_type = "generation_only"
# Replace the string prompt with prompt_tokens_ids
if isinstance(request, CompletionRequest):
request.prompt = ctx_response.prompt_token_ids
elif isinstance(request, ChatCompletionRequest):
request.prompt_token_ids = ctx_response.prompt_token_ids
if ctx_response:
request.disaggregated_params = ctx_response.choices[0].disaggregated_params
request.disaggregated_params.request_type = "generation_only"
# Replace the string prompt with prompt_tokens_ids
if isinstance(request, CompletionRequest):
request.prompt = ctx_response.prompt_token_ids
elif isinstance(request, ChatCompletionRequest):
request.prompt_token_ids = ctx_response.prompt_token_ids
else:
# No context response: this is either a generation-only request or a generation-first disagg request
request.disaggregated_params = DisaggregatedParams(
request_type="generation_only",
ctx_request_id=disagg_request_id,
disagg_request_id=disagg_request_id,
schedule_style=self._schedule_style,
)
if ctx_server_info and "server_info" in ctx_server_info:
disaggregated_params = ctx_server_info["server_info"].get("disaggregated_params", {})
if disaggregated_params:
request.disaggregated_params = request.disaggregated_params.model_copy(
update=disaggregated_params
)
request.disaggregated_params.disagg_request_id = disagg_request_id
return request
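A toy illustration of the merge performed above, using the OpenAI-protocol DisaggregatedParams imported in this module; the endpoint value and the shape of the server_info payload are assumptions.

base = DisaggregatedParams(
    request_type="generation_only",
    ctx_request_id=7,
    disagg_request_id=7,
    schedule_style=DisaggScheduleStyle.GENERATION_FIRST,
)
# Fields advertised by the context server via /server_info override the defaults.
merged = base.model_copy(update={"ctx_info_endpoint": "tcp://ctx-0:7788"})
assert merged.disagg_request_id == 7
assert merged.ctx_info_endpoint == "tcp://ctx-0:7788"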
async def _check_conditional_disagg(self, request: UCompletionRequest) -> bool:
@ -322,3 +351,37 @@ class OpenAIDisaggregatedService(OpenAIService):
"Invalid disaggregated params in context phase response. disagg_request_id is None"
)
return ctx_response
async def _send_disagg_request_gen_first(
self, request: UCompletionRequest, hooks: Optional[ResponseHooks] = None
) -> UCompletionResponse:
if hooks:
hooks.on_req_begin(request)
# An empty server means the client decides which server to use
need_ctx = not (await self._check_gen_only_disagg(request))
ctx_server, gen_server = None, None
ctx_server_info = None
tasks = []
ctx_req, gen_req = None, None
disagg_request_id = get_global_disagg_request_id(self._config.node_id)
if need_ctx:
ctx_server, ctx_server_info = await self._ctx_router.get_next_server(request)
ctx_req = self._get_ctx_request(request, disagg_request_id)
tasks.append(
asyncio.create_task(
self._ctx_client.send_request(ctx_req, server=ctx_server, hooks=hooks)
)
)
gen_req = self._get_gen_request(
request,
ctx_response=None,
disagg_request_id=disagg_request_id,
ctx_server_info=ctx_server_info,
)
tasks.append(
asyncio.create_task(
self._gen_client.send_request(gen_req, server=gen_server, hooks=hooks)
)
)
responses = await asyncio.gather(*tasks)
return responses[-1]

View File

@ -36,7 +36,8 @@ from typing_extensions import Annotated, Required, TypeAlias, TypedDict
from tensorrt_llm.executor.request import LoRARequest
from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams
from tensorrt_llm.llmapi import GuidedDecodingParams, SamplingParams
from tensorrt_llm.llmapi import (DisaggScheduleStyle, GuidedDecodingParams,
SamplingParams)
from tensorrt_llm.llmapi.reasoning_parser import ReasoningParserFactory
@ -121,6 +122,7 @@ class DisaggregatedParams(OpenAIBaseModel):
disagg_request_id: Optional[int] = None
ctx_dp_rank: Optional[int] = None
ctx_info_endpoint: Optional[str] = None
schedule_style: Optional[DisaggScheduleStyle] = None
class ErrorResponse(OpenAIBaseModel):
@ -1095,7 +1097,9 @@ def to_disaggregated_params(
draft_tokens=tllm_disagg_params.draft_tokens,
disagg_request_id=tllm_disagg_params.disagg_request_id,
ctx_dp_rank=tllm_disagg_params.ctx_dp_rank,
ctx_info_endpoint=tllm_disagg_params.ctx_info_endpoint)
ctx_info_endpoint=tllm_disagg_params.ctx_info_endpoint,
schedule_style=tllm_disagg_params.schedule_style,
)
def to_llm_disaggregated_params(
@ -1111,7 +1115,9 @@ def to_llm_disaggregated_params(
draft_tokens=disaggregated_params.draft_tokens,
disagg_request_id=disaggregated_params.disagg_request_id,
ctx_dp_rank=disaggregated_params.ctx_dp_rank,
ctx_info_endpoint=disaggregated_params.ctx_info_endpoint)
ctx_info_endpoint=disaggregated_params.ctx_info_endpoint,
schedule_style=disaggregated_params.schedule_style,
)
UCompletionRequest = Union[CompletionRequest, ChatCompletionRequest]

View File

@ -290,6 +290,9 @@ class OpenAIServer:
self.app.add_api_route("/update_weights",
self.update_weights,
methods=["POST"])
self.app.add_api_route("/server_info",
self.get_server_info,
methods=["GET"])
if self.llm.args.return_perf_metrics:
# register /prometheus/metrics
self.mount_metrics()
@ -1131,6 +1134,9 @@ class OpenAIServer:
await self.llm.collective_rpc('update_weights', args=(request.weights,))
return JSONResponse(content={"status": "success"})
async def get_server_info(self) -> JSONResponse:
return JSONResponse(content={"disaggregated_params": self.llm.disaggregated_params})
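The endpoint simply surfaces llm.disaggregated_params. A client-side sketch of fetching it (server address and timeout are illustrative), along the same lines as the router's _fetch_server_info further below:

import asyncio
import aiohttp

async def query_server_info(server: str) -> dict:
    # Returns {"disaggregated_params": {...}} as served by get_server_info above.
    async with aiohttp.ClientSession() as session:
        async with session.get(f"http://{server}/server_info",
                               timeout=aiohttp.ClientTimeout(total=5)) as response:
            return await response.json()

# Example: asyncio.run(query_server_info("localhost:8000"))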
async def __call__(self, host, port, sockets: list[socket.socket] | None = None):
# Store the binding address for server registration
self.binding_addr = f"http://{host}:{port}"

View File

@ -156,6 +156,7 @@ class Router(ABC):
**kwargs):
self._servers = servers or []
self._metadata_server = metadata_server
self._server_info: dict[str, dict] = {}
self._server_role = server_role
self._lock = asyncio.Lock()
self._monitor_task = None
@ -176,10 +177,26 @@ class Router(ABC):
def servers(self) -> List[str]:
return self._servers
async def _fetch_server_info(self, server: str, timeout: float) -> dict:
session = aiohttp.ClientSession()
try:
async with session.get(f"http://{server}/server_info",
timeout=timeout) as response:
return await response.json()
except Exception as e:
logger.error(f"Error fetching server info for server {server}: {e}")
finally:
await session.close()
return {}
async def _prepare_server(self, server: str):
if self._server_preparation_func:
await self._server_preparation_func(server)
self._server_info[server] = await self._fetch_server_info(
server, self._health_check_timeout)
logger.info(f"server is ready with info: {self._server_info[server]}")
async def prepare_servers(self, servers: Optional[List[str]] = None):
for server in servers or self._servers:
await self._prepare_server(server)
@ -207,6 +224,7 @@ class Router(ABC):
old_server for old_server in old_servers if old_server != server
]
self._on_servers_updated(old_servers, self._servers)
self._server_info.pop(server, None)
logger.debug(
f"Removed server {server}, current server list: {self._servers}")
@ -471,7 +489,7 @@ class RoundRobinRouter(Router):
raise ValueError(
f"No available servers after excluding {exclude_server}"
)
return server, {}
return server, {"server_info": self._server_info.get(server, {})}
async def finish_request(self, request: OpenAIRequest):
pass
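With this change, the extra dict returned next to the chosen server carries that server's cached /server_info payload. A sketch of how a caller such as the disagg service can read it (router and request are assumed to be an existing Router instance and OpenAI request):

async def pick_server_with_info(router, request):
    server, extra = await router.get_next_server(request)
    # The target server's advertised disaggregated params ride along in "server_info".
    disagg_params = extra.get("server_info", {}).get("disaggregated_params", {})
    return server, disagg_params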
@ -558,7 +576,7 @@ class LoadBalancingRouter(Router):
self._req_routing_table[id(request)] = server
return server, {}
return server, {"server_info": self._server_info.get(server, {})}
def _get_server_load(self, server):
return self._server_state[server]._num_active_tokens if self._use_tokens \
@ -675,6 +693,7 @@ class KvCacheAwareRouter(Router):
"block_hashes": block_hashes, # list[list[int]]
"token_lists": token_lists, # list[list[int]]
"matches": matches, # list[int]
"server_info": self._server_info.get(server, {}),
}
async def finish_request(self,

View File

@ -31,6 +31,7 @@ l0_a10:
- unittest/others/test_tracing.py
- unittest/disaggregated/test_disagg_openai_client.py
- unittest/disaggregated/test_disagg_utils.py
- unittest/disaggregated/test_openai_disagg_service.py
- unittest/disaggregated/test_router.py
- unittest/disaggregated/test_remoteDictionary.py
- unittest/disaggregated/test_agent_multi_backends.py

View File

@ -311,3 +311,6 @@ properties:
llm_id:
annotation: str
default: inspect._empty
disaggregated_params:
annotation: dict
default: None

View File

@ -0,0 +1,153 @@
import asyncio
from unittest.mock import AsyncMock
import pytest
from tensorrt_llm.llmapi.disagg_utils import DisaggServerConfig
from tensorrt_llm.serve.openai_disagg_service import OpenAIDisaggregatedService
from tensorrt_llm.serve.openai_protocol import (
CompletionRequest,
CompletionResponse,
CompletionResponseChoice,
DisaggregatedParams,
DisaggScheduleStyle,
UsageInfo,
)
from tensorrt_llm.serve.router import Router
def _client_factory(*_args, **_kwargs):
return AsyncMock()
def _make_service(schedule_style: str) -> OpenAIDisaggregatedService:
config = DisaggServerConfig(server_configs=[], schedule_style=schedule_style)
ctx_router = AsyncMock(spec=Router)
gen_router = AsyncMock(spec=Router)
return OpenAIDisaggregatedService(
config, ctx_router, gen_router, client_factory=_client_factory
)
def _make_completion_response(
text: str,
finish_reason: str,
disagg_request_id: int = 42,
prompt_token_ids=None,
context_only=True,
) -> CompletionResponse:
if prompt_token_ids is None:
prompt_token_ids = [1, 2, 3]
return CompletionResponse(
model="test-model",
usage=UsageInfo(prompt_tokens=1, completion_tokens=1),
prompt_token_ids=prompt_token_ids,
choices=[
CompletionResponseChoice(
index=0,
text=text,
finish_reason=finish_reason,
disaggregated_params=DisaggregatedParams(
request_type="context_only" if context_only else "generation_only",
disagg_request_id=disagg_request_id,
ctx_request_id=disagg_request_id,
),
)
],
)
async def _mock_streaming_response(chunks):
for chunk in chunks:
await asyncio.sleep(0)
yield chunk
@pytest.mark.asyncio
@pytest.mark.parametrize("stream", [False, True], ids=["non-streaming", "streaming"])
@pytest.mark.parametrize("schedule_style", ["context_first", "generation_first"])
async def test_send_disagg_request(monkeypatch, stream, schedule_style):
# Simulate different orderings of ctx and gen responses in gen-first mode
ctx_gen_delay = [
(0.2, 0.1),
(0.2, 0.2),
(0.1, 0.2),
]
monkeypatch.delenv("TRTLLM_DISAGG_BENCHMARK_GEN_ONLY", raising=False)
service = _make_service(schedule_style)
if schedule_style == "context_first":
assert service._send_disagg_request == service._send_disagg_request_ctx_first
assert service._schedule_style == DisaggScheduleStyle.CONTEXT_FIRST
else:
assert service._send_disagg_request == service._send_disagg_request_gen_first
assert service._schedule_style == DisaggScheduleStyle.GENERATION_FIRST
opaque_state = "opaque_state"
for i, (ctx_delay, gen_delay) in enumerate(ctx_gen_delay):
stream_chunks = [b"data: gen-0\n\n", b"data: gen-1\n\n"]
service._ctx_client = AsyncMock()
service._gen_client = AsyncMock()
ctx_server_info = {
"server_info": {"disaggregated_params": {"encoded_opaque_state": opaque_state}}
}
service._ctx_router.get_next_server = AsyncMock(return_value=("ctx:9000", ctx_server_info))
service._gen_router.get_next_server = AsyncMock(
return_value=("gen:9001", {"server_info": {}})
)
resp_text = f"response-{i}"
async def _delayed_ctx_response(*_args, **_kwargs):
request = _args[0]
server, info = await service._ctx_router.get_next_server(request)
await asyncio.sleep(ctx_delay)
return _make_completion_response(
resp_text,
finish_reason="length",
disagg_request_id=request.disaggregated_params.disagg_request_id,
context_only=True,
)
async def _delayed_gen_response(*_args, **_kwargs):
request = _args[0]
server, info = await service._gen_router.get_next_server(request)
await asyncio.sleep(gen_delay)
if stream:
return _mock_streaming_response(stream_chunks)
return _make_completion_response(
resp_text,
finish_reason="stop",
disagg_request_id=request.disaggregated_params.disagg_request_id,
context_only=False,
)
service._ctx_client.send_request = AsyncMock(side_effect=_delayed_ctx_response)
service._gen_client.send_request = AsyncMock(side_effect=_delayed_gen_response)
request = CompletionRequest(model="test-model", prompt="hello", stream=stream)
result = await service._send_disagg_request(request)
ctx_req = service._ctx_client.send_request.call_args.args[0]
assert ctx_req.disaggregated_params.request_type == "context_only"
gen_req = service._gen_client.send_request.call_args.args[0]
assert gen_req.disaggregated_params.request_type == "generation_only"
if schedule_style == "generation_first":
assert gen_req.disaggregated_params.encoded_opaque_state == opaque_state
assert (
gen_req.disaggregated_params.ctx_request_id
== ctx_req.disaggregated_params.disagg_request_id
)
if stream:
assert hasattr(result, "__aiter__")
chunks = [chunk async for chunk in result]
assert chunks == stream_chunks
else:
assert result.model == "test-model"
assert result.usage.prompt_tokens == 1
assert len(result.choices) == 1
assert result.choices[0].text == resp_text
assert result.choices[0].finish_reason == "stop"
assert (
result.choices[0].disaggregated_params.disagg_request_id
== ctx_req.disaggregated_params.disagg_request_id
)