diff --git a/docs/source/features/disagg-serving.md b/docs/source/features/disagg-serving.md
index 5811a52109..88feb11b08 100644
--- a/docs/source/features/disagg-serving.md
+++ b/docs/source/features/disagg-serving.md
@@ -6,6 +6,7 @@
 - [NIXL Backend Configuration](#nixl-backend-configuration)
 - [Overlap Optimization](#Overlap-Optimization)
 - [Cache Layout Transformation](#Cache-Layout-Transformation)
+- [Unique Global Request ID](#Unique-Global-Request-ID)
 - [Usage](#Usage)
 - [Dynamo](#Dynamo)
 - [trtllm-serve](#trtllm-serve)
@@ -90,6 +91,14 @@ To minimize KV cache transmission latency, TensorRT LLM currently uses direct tr
 The optimizations required for KV cache transmission vary depending on whether it's single-node multi-GPU, multi-node multi-GPU, or different GPU models. To accommodate this, TensorRT LLM provides a set of environment variables for selection in different environments. Please refer to the following section for details [Environment Variables](#Environment-Variables).
 
+### Unique Global Request ID
+
+A disaggregated-serving request can provide a unique global request ID via `DisaggregatedParams.disagg_request_id`.
+When this field is a positive integer, the context and generation requests share that value as their internal request ID, which enables end-to-end request tracking. In this case, the context and generation requests must not be routed to the same worker, because the shared request ID would collide there.
+To avoid collisions with worker-local or warm-up requests, it is recommended to use a value larger than `1 << 42` (4398046511104).
+If the field is unset or non-positive, the context and generation requests instead receive separate worker-local sequence IDs, which rotate within the range `(0, 1 << 42)`.
+This field is currently optional; however, some forthcoming features will depend on this unique identifier.
+
 ## Usage
 
 ### Dynamo
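For illustration, the sketch below shows one way a client could allocate IDs above the recommended floor and attach the same ID to both phases of a request. Only `disagg_request_id` comes from this change; the import path `tensorrt_llm.disaggregated_params` and the `request_type` values are assumptions about the existing LLM API.

```python
# Sketch only: allocating disagg_request_id values above the worker-local range.
# The import path and request_type values are assumed, not part of this change.
import itertools

from tensorrt_llm.disaggregated_params import DisaggregatedParams  # assumed path

ID_FLOOR = 1 << 42  # worker-local sequence IDs rotate within (0, 1 << 42)
_ids = itertools.count(ID_FLOOR + 1)  # process-wide, monotonically increasing

def make_ctx_params() -> DisaggregatedParams:
    """Build context-phase params carrying a fresh global request ID."""
    return DisaggregatedParams(
        request_type="context_only",  # assumed value
        disagg_request_id=next(_ids),  # the new field documented above
    )

ctx_params = make_ctx_params()
# The generation request must reuse the same ID so both phases share it.
gen_params = DisaggregatedParams(
    request_type="generation_only",  # assumed value
    disagg_request_id=ctx_params.disagg_request_id,
)
assert gen_params.disagg_request_id > ID_FLOOR
```

Note that a per-process counter is only unique within one client; a deployment with several clients would need a coordinated scheme, for example disjoint ID ranges above the floor.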
diff --git a/tensorrt_llm/serve/openai_disagg_service.py b/tensorrt_llm/serve/openai_disagg_service.py
index b9e88fe5f5..9f86974e50 100644
--- a/tensorrt_llm/serve/openai_disagg_service.py
+++ b/tensorrt_llm/serve/openai_disagg_service.py
@@ -108,10 +108,9 @@ class OpenAIDisaggregatedService(OpenAIService):
         if hooks:
             hooks.on_req_begin(request)
         # empty server means client decides which server to use
-        reserved_gen_server = None
-        reserved_ctx_server = None
+        ctx_server = None
         # reserve a gen_server if conditional disagg is needed
-        reserved_gen_server, need_ctx = await self._check_conditional_disagg(request)
+        gen_server, need_ctx = await self._check_conditional_disagg(request)
         need_ctx = need_ctx and not await self._check_gen_only_disagg(request)
         ctx_response = None
         gen_req = request
@@ -119,15 +118,20 @@
         if need_ctx:
             ctx_req = self._get_ctx_request(request, disagg_request_id)
             # ctx generator is empty
+            ctx_server, _ = await self._ctx_router.get_next_server(
+                ctx_req, exclude_server=gen_server
+            )
             ctx_response = await self._ctx_client.send_request(
-                ctx_req, server=reserved_ctx_server, hooks=hooks
+                ctx_req, server=ctx_server, hooks=hooks
             )
             await self._verify_ctx_response(ctx_response)
             gen_req = self._get_gen_request(request, ctx_response, disagg_request_id)
         if ctx_response is None or self._need_gen(ctx_response):
-            return await self._gen_client.send_request(
-                gen_req, server=reserved_gen_server, hooks=hooks
-            )
+            if not gen_server:
+                gen_server, _ = await self._gen_router.get_next_server(
+                    gen_req, exclude_server=ctx_server
+                )
+            return await self._gen_client.send_request(gen_req, server=gen_server, hooks=hooks)
         else:
             if request.stream:
                 # ctx client will never return a generator when streaming is requested
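To make the new selection order concrete, here is a self-contained sketch of the flow above: a generation server reserved by conditional disaggregation is excluded when the context server is chosen, and the chosen context server is excluded when the generation server is chosen, so a shared `disagg_request_id` never lands twice on the same worker in mixed deployments. `StubRouter` is illustrative only; it filters candidates instead of double-advancing like the real `RoundRobinRouter`.

```python
# Self-contained sketch of the mutual-exclusion routing flow; StubRouter is
# illustrative and not the project's API.
import asyncio
from typing import Optional

class StubRouter:
    def __init__(self, servers: list[str]):
        self._servers = servers
        self._idx = 0

    async def get_next_server(self, request, exclude_server: Optional[str] = None):
        # Filter out the excluded server, then pick round-robin style.
        candidates = [s for s in self._servers if s != exclude_server]
        if not candidates:
            raise ValueError(f"No available servers after excluding {exclude_server}")
        server = candidates[self._idx % len(candidates)]
        self._idx += 1
        return server, {}

async def route(ctx_router, gen_router, request, reserved_gen_server=None):
    # A gen server reserved by conditional disagg must not serve the context
    # phase, and the chosen ctx server must not also serve the gen phase.
    ctx_server, _ = await ctx_router.get_next_server(
        request, exclude_server=reserved_gen_server)
    gen_server = reserved_gen_server
    if not gen_server:
        gen_server, _ = await gen_router.get_next_server(
            request, exclude_server=ctx_server)
    return ctx_server, gen_server

pool = ["localhost:8001", "localhost:8002", "localhost:8003"]
ctx, gen = asyncio.run(route(StubRouter(pool), StubRouter(pool), request=None))
assert ctx != gen  # the two phases never share a worker
```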
diff --git a/tensorrt_llm/serve/router.py b/tensorrt_llm/serve/router.py
index a3d3939886..a3bc0783f4 100644
--- a/tensorrt_llm/serve/router.py
+++ b/tensorrt_llm/serve/router.py
@@ -211,8 +211,11 @@
             f"Removed server {server}, current server list: {self._servers}")
 
     @abstractmethod
-    async def get_next_server(self, request: OpenAIRequest) -> tuple[str, dict]:
-        '''Select server by request and return some intermediate information'''
+    async def get_next_server(
+            self,
+            request: OpenAIRequest,
+            exclude_server: Optional[str] = None) -> tuple[str, dict]:
+        '''Select a server for the request and return intermediate information; exclude_server, if given, is excluded from the selection.'''
 
     @abstractmethod
     async def finish_request(self, request: OpenAIRequest):
@@ -441,15 +444,17 @@
         self._server_idx = 0
 
     def _on_servers_updated(self, old_servers, new_servers):
-        """Reset the index when servers are removed to prevent index out of bounds errors."""
-        if len(new_servers) < len(old_servers):
-            # Servers were removed, reset the index
-            self._server_idx = 0
-        elif self._server_idx >= len(new_servers):
-            # Safety check: ensure index is always within bounds
-            self._server_idx = 0
+        pass  # the modulo in _get_next_server keeps the index in bounds
 
-    async def get_next_server(self, request: OpenAIRequest) -> tuple[str, dict]:
+    def _get_next_server(self) -> str:
+        server = self._servers[self._server_idx % len(self._servers)]
+        self._server_idx += 1
+        return server
+
+    async def get_next_server(
+            self,
+            request: OpenAIRequest,
+            exclude_server: Optional[str] = None) -> tuple[str, dict]:
         if not self._servers:
             if self._metadata_server:
                 raise ValueError(
@@ -459,12 +464,13 @@
                 raise ValueError(f"No {self._server_role} servers available")
 
         async with self._lock:
-            # Safety check: ensure index is within bounds
-            if self._server_idx >= len(self._servers):
-                self._server_idx = 0
-
-            server = self._servers[self._server_idx]
-            self._server_idx = (self._server_idx + 1) % len(self._servers)
+            server = self._get_next_server()
+            if exclude_server and server == exclude_server:
+                server = self._get_next_server()  # advance once past the excluded server
+                if server == exclude_server:  # only possible with a single-server pool
+                    raise ValueError(
+                        f"No available servers after excluding {exclude_server}"
+                    )
             return server, {}
 
     async def finish_request(self, request: OpenAIRequest):
@@ -517,7 +523,10 @@
             heapq.heappush(self._server_load_heap,
                            (self._get_server_load(server), server))
 
-    async def get_next_server(self, request: OpenAIRequest) -> tuple[str, dict]:
+    async def get_next_server(
+            self,
+            request: OpenAIRequest,
+            exclude_server: Optional[str] = None) -> tuple[str, dict]:
         if not self._servers:
             if self._metadata_server:
                 raise ValueError(
@@ -527,8 +536,23 @@
             raise ValueError(f"No {self._server_role} servers available")
 
         async with self._lock:
-            server = heapq.heappop(self._server_load_heap)[1]
+            if exclude_server:  # build a temporary heap without the excluded server
+                server_load_heap = [(self._get_server_load(server), server)
+                                    for server in self._servers
+                                    if server != exclude_server]
+                heapq.heapify(server_load_heap)
+            else:
+                server_load_heap = self._server_load_heap
+
+            server = heapq.heappop(server_load_heap)[1]
             await self._server_state[server].increment_load(request)
+            # maintain the member heap: adopt the filtered heap, then re-add the excluded server
+            if exclude_server:
+                self._server_load_heap = server_load_heap
+                if exclude_server in self._server_state:
+                    heapq.heappush(
+                        self._server_load_heap,
+                        (self._get_server_load(exclude_server), exclude_server))
             heapq.heappush(self._server_load_heap,
                            (self._get_server_load(server), server))
@@ -604,9 +628,15 @@
             tokenizer = self._tokenizers[request.model]
         return [tokenizer(prompt)["input_ids"] for prompt in prompts]
 
-    async def get_next_server(self, request: OpenAIRequest) -> tuple[str, dict]:
+    async def get_next_server(
+            self,
+            request: OpenAIRequest,
+            exclude_server: Optional[str] = None) -> tuple[str, dict]:
         async with self._lock:
-            servers = list(self._server_state.keys())
+            servers = [
+                server for server in self._server_state.keys()
+                if server != exclude_server
+            ]
             token_lists = self._tokenize(request)
             block_hashes: list[list[int]] = []
             for token_list in token_lists:
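The heap bookkeeping in the `LoadBalancingRouter` hunk is easier to see outside diff form. Below is a standalone sketch of the same technique, assuming plain integer loads: build a temporary heap without the excluded server, pop the least-loaded candidate, and leave the excluded server eligible for later picks. Names here are illustrative.

```python
# Standalone sketch of the LoadBalancingRouter exclusion strategy using plain
# dicts; the O(n) heap rebuild is acceptable for small server pools.
import heapq
from typing import Optional

def pick_least_loaded(loads: dict[str, int],
                      exclude_server: Optional[str] = None) -> str:
    heap = [(load, server) for server, load in loads.items()
            if server != exclude_server]
    if not heap:
        raise ValueError(f"No available servers after excluding {exclude_server}")
    heapq.heapify(heap)
    _, server = heapq.heappop(heap)  # least-loaded remaining candidate
    loads[server] += 1  # account for the request just routed to this server
    return server

loads = {"localhost:8001": 3, "localhost:8002": 1, "localhost:8003": 2}
assert pick_least_loaded(loads, exclude_server="localhost:8002") == "localhost:8003"
assert pick_least_loaded(loads) == "localhost:8002"  # excluded server eligible again
```

An alternative would be lazy deletion (skipping the excluded entry on pop), but rebuilding keeps the invariant that the member heap always reflects current loads.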
diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_mixed.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_mixed.yaml
index 27d7ec4ee8..dcc40a6a8b 100644
--- a/tests/integration/defs/disaggregated/test_configs/disagg_config_mixed.yaml
+++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_mixed.yaml
@@ -6,13 +6,14 @@ backend: "pytorch"
 cuda_graph_config: null
 disable_overlap_scheduler: True
 context_servers:
-  num_instances: 1
+  num_instances: 2
   tensor_parallel_size: 1
   pipeline_parallel_size: 1
   cache_transceiver_config:
     backend: DEFAULT
   urls:
       - "localhost:8001"
+      - "localhost:8002"
 generation_servers:
   num_instances: 2
   tensor_parallel_size: 1
diff --git a/tests/unittest/disaggregated/test_router.py b/tests/unittest/disaggregated/test_router.py
index eb4beffa4a..73310f5f03 100644
--- a/tests/unittest/disaggregated/test_router.py
+++ b/tests/unittest/disaggregated/test_router.py
@@ -424,3 +424,45 @@ async def test_server_health_check(mock_metadata_server, router_class):
     live_servers = await router.check_servers_health(servers)
     assert len(live_servers) == 1, "Should have one healthy server"
     assert server_url2 in live_servers, "Second server should still be present"
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "router_class", [RoundRobinRouter, LoadBalancingRouter, KvCacheAwareRouter])
+async def test_get_next_server_exclude_server(router_class):
+    servers = ["server1", "server2", "server3"]
+    router = router_class(server_role="context", servers=servers)
+    exclude_server2 = {server: 0 for server in servers}
+    exclude_server3 = {server: 0 for server in servers}
+    for _ in range(10):
+        server, _ = await router.get_next_server(CompletionRequest(
+            model="TinyLlama", prompt=[[10] * 10]),
+                                                 exclude_server="server2")
+        exclude_server2[server] += 1
+        server, _ = await router.get_next_server(CompletionRequest(
+            model="TinyLlama", prompt=[[10] * 10]),
+                                                 exclude_server="server3")
+        exclude_server3[server] += 1
+    if router_class == KvCacheAwareRouter:
+        # KvCacheAwareRouter is not load-balanced; only check the excluded server is never picked
+        assert exclude_server2["server2"] == 0
+        assert exclude_server3["server3"] == 0
+    else:
+        assert exclude_server2["server1"] > 0 and exclude_server2[
+            "server2"] == 0 and exclude_server2["server3"] > 0, exclude_server2
+        assert exclude_server3["server1"] > 0 and exclude_server3[
+            "server2"] > 0 and exclude_server3["server3"] == 0, exclude_server3
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "router_class", [RoundRobinRouter, LoadBalancingRouter, KvCacheAwareRouter])
+async def test_get_next_server_exclude_server_insufficient(router_class):
+    servers = ["server1"]
+    router = router_class(server_role="context",
+                          servers=servers,
+                          use_tokens=False)
+    with pytest.raises(Exception):
+        await router.get_next_server(CompletionRequest(model="TinyLlama",
+                                                       prompt=[[10] * 10]),
+                                     exclude_server=servers[0])
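Finally, a quick illustration of the failure mode the last test exercises. `MiniRoundRobin` is an illustrative stand-in, not the real `RoundRobinRouter`: the index advances one extra step past an excluded server, so a single-server pool that is entirely excluded fails loudly instead of looping.

```python
# Illustrative stand-in for the double-advance guard; not the real class.
class MiniRoundRobin:
    def __init__(self, servers):
        self._servers = servers
        self._idx = 0

    def _next(self):
        server = self._servers[self._idx % len(self._servers)]
        self._idx += 1
        return server

    def pick(self, exclude_server=None):
        server = self._next()
        if exclude_server and server == exclude_server:
            server = self._next()  # advance once past the excluded server
            if server == exclude_server:  # only possible with a single server
                raise ValueError(
                    f"No available servers after excluding {exclude_server}")
        return server

rr = MiniRoundRobin(["server1", "server2", "server3"])
# "server2" is skipped; the remaining servers still alternate.
assert [rr.pick(exclude_server="server2") for _ in range(4)] == \
       ["server1", "server3", "server1", "server3"]

single = MiniRoundRobin(["server1"])
try:
    single.pick(exclude_server="server1")
except ValueError as e:
    print(e)  # No available servers after excluding server1
```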