[https://nvbugs/5834212][fix] prevent routing ctx and gen requests to the same worker; update doc for unique disagg ID (#11095)

Signed-off-by: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com>
Lizhi Zhou 2026-02-02 09:54:33 +08:00 committed by GitHub
parent ea49afdf0b
commit b00e8338ec
5 changed files with 114 additions and 28 deletions


@ -6,6 +6,7 @@
- [NIXL Backend Configuration](#nixl-backend-configuration)
- [Overlap Optimization](#Overlap-Optimization)
- [Cache Layout Transformation](#Cache-Layout-Transformation)
- [Unique Global Request ID](#Unique-Global-Request-ID)
- [Usage](#Usage)
- [Dynamo](#Dynamo)
- [trtllm-serve](#trtllm-serve)
@ -90,6 +91,14 @@ To minimize KV cache transmission latency, TensorRT LLM currently uses direct tr
The optimizations required for KV cache transmission vary depending on whether it's single-node multi-GPU, multi-node multi-GPU, or different GPU models. To accommodate this, TensorRT LLM provides a set of environment variables for selection in different environments. Please refer to the following section for details [Environment Variables](#Environment-Variables).
### Unique Global Request ID
A disaggregated-serving request can provide a unique global request ID via `DisaggregatedParams.disagg_request_id`.
When this field is a positive integer, the context and generation requests share that value as their internal request ID, which enables end-to-end tracking.
To avoid collisions with worker-local or warm-up requests, it is recommended to use a value larger than `1 << 42` (4398046511104).
If the field is unset or non-positive, the context and generation requests instead receive separate worker-local sequence IDs, assigned by the respective workers and rotating within the range `(0, 1 << 42)`. When `disagg_request_id` is specified, do not route the context and generation requests to the same worker, since both requests would then carry the same internal request ID on that worker.
This field is optional at present; however, some forthcoming features will depend on this unique identifier.
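For illustration, a minimal sketch of how a client could populate this field is shown below. It assumes `DisaggregatedParams` is imported from `tensorrt_llm.llmapi` and accepts `disagg_request_id` as a constructor argument; adapt it to your own client code.

```python
# Minimal sketch: derive one global ID above 1 << 42 and reuse it for both the
# context and the generation leg of the same user request.
import secrets

from tensorrt_llm.llmapi import DisaggregatedParams  # assumed import path

# Strictly above 1 << 42, so it cannot collide with worker-local sequence IDs.
disagg_request_id = (1 << 42) + 1 + secrets.randbits(20)

ctx_params = DisaggregatedParams(request_type="context_only",
                                 disagg_request_id=disagg_request_id)
# The generation-side params additionally carry fields returned by the context
# phase (first generated tokens, transceiver state); they are omitted here.
gen_params = DisaggregatedParams(request_type="generation_only",
                                 disagg_request_id=disagg_request_id)
```

Both halves of the request then share the same global ID, and, as noted above, they should be routed to different context and generation workers.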
## Usage
### Dynamo


@ -108,10 +108,9 @@ class OpenAIDisaggregatedService(OpenAIService):
if hooks:
hooks.on_req_begin(request)
# empty server means client decides which server to use
reserved_gen_server = None
reserved_ctx_server = None
ctx_server = None
# reserve a gen_server if conditional disagg is needed
reserved_gen_server, need_ctx = await self._check_conditional_disagg(request)
gen_server, need_ctx = await self._check_conditional_disagg(request)
need_ctx = need_ctx and not await self._check_gen_only_disagg(request)
ctx_response = None
gen_req = request
@ -119,15 +118,20 @@ class OpenAIDisaggregatedService(OpenAIService):
if need_ctx:
ctx_req = self._get_ctx_request(request, disagg_request_id)
# ctx generator is empty
ctx_server, _ = await self._ctx_router.get_next_server(
ctx_req, exclude_server=gen_server
)
ctx_response = await self._ctx_client.send_request(
ctx_req, server=reserved_ctx_server, hooks=hooks
ctx_req, server=ctx_server, hooks=hooks
)
await self._verify_ctx_response(ctx_response)
gen_req = self._get_gen_request(request, ctx_response, disagg_request_id)
if ctx_response is None or self._need_gen(ctx_response):
return await self._gen_client.send_request(
gen_req, server=reserved_gen_server, hooks=hooks
)
if not gen_server:
gen_server, _ = await self._gen_router.get_next_server(
gen_req, exclude_server=ctx_server
)
return await self._gen_client.send_request(gen_req, server=gen_server, hooks=hooks)
else:
if request.stream:
# ctx client will never return a generator when streaming is requested


@ -211,8 +211,11 @@ class Router(ABC):
f"Removed server {server}, current server list: {self._servers}")
@abstractmethod
async def get_next_server(self, request: OpenAIRequest) -> tuple[str, dict]:
'''Select server by request and return some intermediate information'''
async def get_next_server(
self,
request: OpenAIRequest,
exclude_server: Optional[str] = None) -> tuple[str, dict]:
'''Select a server for the request and return intermediate information; exclude_server, if given, is excluded from the selection.'''
@abstractmethod
async def finish_request(self, request: OpenAIRequest):
@ -441,15 +444,17 @@ class RoundRobinRouter(Router):
self._server_idx = 0
def _on_servers_updated(self, old_servers, new_servers):
"""Reset the index when servers are removed to prevent index out of bounds errors."""
if len(new_servers) < len(old_servers):
# Servers were removed, reset the index
self._server_idx = 0
elif self._server_idx >= len(new_servers):
# Safety check: ensure index is always within bounds
self._server_idx = 0
pass
async def get_next_server(self, request: OpenAIRequest) -> tuple[str, dict]:
def _get_next_server(self) -> str:
server = self._servers[self._server_idx % len(self._servers)]
self._server_idx += 1
return server
async def get_next_server(
self,
request: OpenAIRequest,
exclude_server: Optional[str] = None) -> tuple[str, dict]:
if not self._servers:
if self._metadata_server:
raise ValueError(
@ -459,12 +464,13 @@ class RoundRobinRouter(Router):
raise ValueError(f"No {self._server_role} servers available")
async with self._lock:
# Safety check: ensure index is within bounds
if self._server_idx >= len(self._servers):
self._server_idx = 0
server = self._servers[self._server_idx]
self._server_idx = (self._server_idx + 1) % len(self._servers)
server = self._get_next_server()
if exclude_server and server == exclude_server:
server = self._get_next_server()
if server == exclude_server:
raise ValueError(
f"No available servers after excluding {exclude_server}"
)
return server, {}
async def finish_request(self, request: OpenAIRequest):
@ -517,7 +523,10 @@ class LoadBalancingRouter(Router):
heapq.heappush(self._server_load_heap,
(self._get_server_load(server), server))
async def get_next_server(self, request: OpenAIRequest) -> tuple[str, dict]:
async def get_next_server(
self,
request: OpenAIRequest,
exclude_server: Optional[str] = None) -> tuple[str, dict]:
if not self._servers:
if self._metadata_server:
raise ValueError(
@ -527,8 +536,23 @@ class LoadBalancingRouter(Router):
raise ValueError(f"No {self._server_role} servers available")
async with self._lock:
server = heapq.heappop(self._server_load_heap)[1]
if exclude_server:
server_load_heap = [(self._get_server_load(server), server)
for server in self._servers
if server != exclude_server]
heapq.heapify(server_load_heap)
else:
server_load_heap = self._server_load_heap
server = heapq.heappop(server_load_heap)[1]
await self._server_state[server].increment_load(request)
# maintain the member heap
if exclude_server:
self._server_load_heap = server_load_heap
if exclude_server in self._server_state:
heapq.heappush(
self._server_load_heap,
(self._get_server_load(exclude_server), exclude_server))
heapq.heappush(self._server_load_heap,
(self._get_server_load(server), server))
@ -604,9 +628,15 @@ class KvCacheAwareRouter(Router):
tokenizer = self._tokenizers[request.model]
return [tokenizer(prompt)["input_ids"] for prompt in prompts]
async def get_next_server(self, request: OpenAIRequest) -> tuple[str, dict]:
async def get_next_server(
self,
request: OpenAIRequest,
exclude_server: Optional[str] = None) -> tuple[str, dict]:
async with self._lock:
servers = list(self._server_state.keys())
servers = list([
server for server in self._server_state.keys()
if server != exclude_server
])
token_lists = self._tokenize(request)
block_hashes: list[list[int]] = []
for token_list in token_lists:
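As a usage sketch of the new `exclude_server` argument (the variable names below are illustrative, not taken from this diff), a caller that has already chosen the context worker can keep the generation request off that worker:

```python
# Illustrative only: ctx_router and gen_router are Router instances for the two
# roles, and request is the incoming OpenAIRequest.
ctx_server, _ = await ctx_router.get_next_server(request)
gen_server, _ = await gen_router.get_next_server(request,
                                                 exclude_server=ctx_server)
# The routers raise if no server other than exclude_server is available.
assert gen_server != ctx_server
```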


@ -6,13 +6,14 @@ backend: "pytorch"
cuda_graph_config: null
disable_overlap_scheduler: True
context_servers:
  num_instances: 1
  num_instances: 2
  tensor_parallel_size: 1
  pipeline_parallel_size: 1
  cache_transceiver_config:
    backend: DEFAULT
  urls:
    - "localhost:8001"
    - "localhost:8002"
generation_servers:
  num_instances: 2
  tensor_parallel_size: 1


@ -424,3 +424,45 @@ async def test_server_health_check(mock_metadata_server, router_class):
    live_servers = await router.check_servers_health(servers)
    assert len(live_servers) == 1, "Should have one healthy server"
    assert server_url2 in live_servers, "Second server should still be present"


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "router_class", [RoundRobinRouter, LoadBalancingRouter, KvCacheAwareRouter])
async def test_get_next_server_exclude_server(router_class):
    servers = ["server1", "server2", "server3"]
    router = router_class(server_role="context", servers=servers)
    exclude_server2 = {server: 0 for server in servers}
    exclude_server3 = {server: 0 for server in servers}
    for _ in range(0, 10):
        server, _ = await router.get_next_server(
            CompletionRequest(model="TinyLlama", prompt=[[10] * 10]),
            exclude_server="server2")
        exclude_server2[server] += 1
        server, _ = await router.get_next_server(
            CompletionRequest(model="TinyLlama", prompt=[[10] * 10]),
            exclude_server="server3")
        exclude_server3[server] += 1
    if router_class == KvCacheAwareRouter:
        # KvCacheAwareRouter is not load-balanced
        assert exclude_server2["server2"] == 0
        assert exclude_server3["server3"] == 0
    else:
        assert exclude_server2["server1"] > 0 and exclude_server2[
            "server2"] == 0 and exclude_server2["server3"] > 0, exclude_server2
        assert exclude_server3["server1"] > 0 and exclude_server3[
            "server2"] > 0 and exclude_server3["server3"] == 0, exclude_server3


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "router_class", [RoundRobinRouter, LoadBalancingRouter, KvCacheAwareRouter])
async def test_get_next_server_exclude_server_insufficient(router_class):
    servers = ["server1"]
    router = router_class(server_role="context",
                          servers=servers,
                          use_tokens=False)
    with pytest.raises(Exception):
        await router.get_next_server(
            CompletionRequest(model="TinyLlama", prompt=[[10] * 10]),
            exclude_server=servers[0])