Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-02-04 18:21:52 +08:00)
[https://nvbugs/5834212][fix] prevent routing ctx and gen requests to the same worker; update doc for unique disagg ID (#11095)
Signed-off-by: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com>
parent ea49afdf0b
commit b00e8338ec
@@ -6,6 +6,7 @@
 - [NIXL Backend Configuration](#nixl-backend-configuration)
 - [Overlap Optimization](#Overlap-Optimization)
 - [Cache Layout Transformation](#Cache-Layout-Transformation)
+- [Unique Global Request ID](#Unique-Global-Request-ID)
 - [Usage](#Usage)
 - [Dynamo](#Dynamo)
 - [trtllm-serve](#trtllm-serve)
@@ -90,6 +91,14 @@ To minimize KV cache transmission latency, TensorRT LLM currently uses direct tr

 The optimizations required for KV cache transmission vary depending on whether it's single-node multi-GPU, multi-node multi-GPU, or different GPU models. To accommodate this, TensorRT LLM provides a set of environment variables for selection in different environments. Please refer to the following section for details: [Environment Variables](#Environment-Variables).

+### Unique Global Request ID
+
+A disaggregated-serving request can provide a unique global request ID via `DisaggregatedParams.disagg_request_id`.
+When this field is a positive integer, the context and generation requests share that value as their internal request ID, which enables end-to-end tracking.
+To avoid collisions with worker-local or warm-up requests, it is recommended to use a value larger than `1 << 42 = 4398046511104`.
+If the field is unset or non-positive, the context and generation requests instead receive separate local sequence IDs, rotating within the range `(0, 1 << 42)`, assigned by the respective workers. When `disagg_request_id` is specified, do not route the context and generation requests to the same worker.
+This field is optional at present; however, some forthcoming features will depend on this unique identifier.
+
 ## Usage

 ### Dynamo
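For illustration, here is a minimal client-side sketch of the ID scheme the new doc section describes. It is not part of this commit; the import path and the `request_type` values are assumptions based on typical TensorRT-LLM disaggregated examples, while `disagg_request_id` and the `1 << 42` lower bound come from the documentation above.

# Minimal sketch, assuming DisaggregatedParams is importable from the LLM API
# and accepts a disagg_request_id keyword (the field described in the doc above).
import secrets

from tensorrt_llm.llmapi import DisaggregatedParams  # assumed import path

LOCAL_ID_RANGE = 1 << 42  # worker-local sequence IDs rotate within (0, 1 << 42)


def make_disagg_request_id() -> int:
    # Any positive integer strictly above 1 << 42 stays clear of worker-local
    # and warm-up request IDs, as recommended in the doc section above.
    return LOCAL_ID_RANGE + 1 + secrets.randbits(40)


request_id = make_disagg_request_id()
# The same ID is attached to both halves of the request so the context and
# generation phases can be tracked end to end; the two requests must be routed
# to different workers.
ctx_params = DisaggregatedParams(request_type="context_only",
                                 disagg_request_id=request_id)
gen_params = DisaggregatedParams(request_type="generation_only",
                                 disagg_request_id=request_id)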
@@ -108,10 +108,9 @@ class OpenAIDisaggregatedService(OpenAIService):
         if hooks:
             hooks.on_req_begin(request)
         # empty server means client decides which server to use
-        reserved_gen_server = None
-        reserved_ctx_server = None
+        ctx_server = None
         # reserve a gen_server if conditional disagg is needed
-        reserved_gen_server, need_ctx = await self._check_conditional_disagg(request)
+        gen_server, need_ctx = await self._check_conditional_disagg(request)
         need_ctx = need_ctx and not await self._check_gen_only_disagg(request)
         ctx_response = None
         gen_req = request
@@ -119,15 +118,20 @@ class OpenAIDisaggregatedService(OpenAIService):
         if need_ctx:
             ctx_req = self._get_ctx_request(request, disagg_request_id)
             # ctx generator is empty
+            ctx_server, _ = await self._ctx_router.get_next_server(
+                ctx_req, exclude_server=gen_server
+            )
             ctx_response = await self._ctx_client.send_request(
-                ctx_req, server=reserved_ctx_server, hooks=hooks
+                ctx_req, server=ctx_server, hooks=hooks
             )
             await self._verify_ctx_response(ctx_response)
             gen_req = self._get_gen_request(request, ctx_response, disagg_request_id)
         if ctx_response is None or self._need_gen(ctx_response):
-            return await self._gen_client.send_request(
-                gen_req, server=reserved_gen_server, hooks=hooks
-            )
+            if not gen_server:
+                gen_server, _ = await self._gen_router.get_next_server(
+                    gen_req, exclude_server=ctx_server
+                )
+            return await self._gen_client.send_request(gen_req, server=gen_server, hooks=hooks)
         else:
             if request.stream:
                 # ctx client will never return a generator when streaming is requested
@@ -211,8 +211,11 @@ class Router(ABC):
             f"Removed server {server}, current server list: {self._servers}")

     @abstractmethod
-    async def get_next_server(self, request: OpenAIRequest) -> tuple[str, dict]:
-        '''Select server by request and return some intermediate information'''
+    async def get_next_server(
+            self,
+            request: OpenAIRequest,
+            exclude_server: Optional[str] = None) -> tuple[str, dict]:
+        '''Select server by request and return some intermediate information, exclude_server is a server to exclude from the selection'''

     @abstractmethod
     async def finish_request(self, request: OpenAIRequest):
@@ -441,15 +444,17 @@ class RoundRobinRouter(Router):
         self._server_idx = 0

     def _on_servers_updated(self, old_servers, new_servers):
-        """Reset the index when servers are removed to prevent index out of bounds errors."""
-        if len(new_servers) < len(old_servers):
-            # Servers were removed, reset the index
-            self._server_idx = 0
-        elif self._server_idx >= len(new_servers):
-            # Safety check: ensure index is always within bounds
-            self._server_idx = 0
+        pass

-    async def get_next_server(self, request: OpenAIRequest) -> tuple[str, dict]:
+    def _get_next_server(self) -> str:
+        server = self._servers[self._server_idx % len(self._servers)]
+        self._server_idx += 1
+        return server
+
+    async def get_next_server(
+            self,
+            request: OpenAIRequest,
+            exclude_server: Optional[str] = None) -> tuple[str, dict]:
         if not self._servers:
             if self._metadata_server:
                 raise ValueError(
@@ -459,12 +464,13 @@ class RoundRobinRouter(Router):
                 raise ValueError(f"No {self._server_role} servers available")

         async with self._lock:
-            # Safety check: ensure index is within bounds
-            if self._server_idx >= len(self._servers):
-                self._server_idx = 0
-
-            server = self._servers[self._server_idx]
-            self._server_idx = (self._server_idx + 1) % len(self._servers)
+            server = self._get_next_server()
+            if exclude_server and server == exclude_server:
+                server = self._get_next_server()
+                if server == exclude_server:
+                    raise ValueError(
+                        f"No available servers after excluding {exclude_server}"
+                    )
             return server, {}

     async def finish_request(self, request: OpenAIRequest):
@@ -517,7 +523,10 @@ class LoadBalancingRouter(Router):
             heapq.heappush(self._server_load_heap,
                            (self._get_server_load(server), server))

-    async def get_next_server(self, request: OpenAIRequest) -> tuple[str, dict]:
+    async def get_next_server(
+            self,
+            request: OpenAIRequest,
+            exclude_server: Optional[str] = None) -> tuple[str, dict]:
         if not self._servers:
             if self._metadata_server:
                 raise ValueError(
@@ -527,8 +536,23 @@ class LoadBalancingRouter(Router):
                 raise ValueError(f"No {self._server_role} servers available")

         async with self._lock:
-            server = heapq.heappop(self._server_load_heap)[1]
+            if exclude_server:
+                server_load_heap = [(self._get_server_load(server), server)
+                                    for server in self._servers
+                                    if server != exclude_server]
+                heapq.heapify(server_load_heap)
+            else:
+                server_load_heap = self._server_load_heap
+
+            server = heapq.heappop(server_load_heap)[1]
             await self._server_state[server].increment_load(request)
+            # maintain the member heap
+            if exclude_server:
+                self._server_load_heap = server_load_heap
+                if exclude_server in self._server_state:
+                    heapq.heappush(
+                        self._server_load_heap,
+                        (self._get_server_load(exclude_server), exclude_server))
             heapq.heappush(self._server_load_heap,
                            (self._get_server_load(server), server))

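As a standalone illustration of the load-balancing change above: when a server must be excluded, a temporary heap is built over the remaining servers, the least-loaded candidate is popped, and the excluded server is later pushed back so it stays eligible for subsequent requests. The sketch below is simplified, with generic names rather than the repository's classes.

# Simplified illustration of least-loaded selection with an excluded server.
import heapq
from typing import Optional


def pick_least_loaded(loads: dict[str, int],
                      exclude: Optional[str] = None) -> str:
    # Build a heap over the eligible servers only; the excluded server is
    # simply left out of the candidate set.
    heap = [(load, server) for server, load in loads.items() if server != exclude]
    if not heap:
        raise ValueError(f"No available servers after excluding {exclude}")
    heapq.heapify(heap)
    _, server = heapq.heappop(heap)
    loads[server] += 1  # account for the request about to be dispatched
    return server


# Example: the generation request already reserved "worker-b", so the context
# request goes to the least-loaded of the remaining workers.
loads = {"worker-a": 3, "worker-b": 1, "worker-c": 2}
assert pick_least_loaded(loads, exclude="worker-b") == "worker-c"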
@@ -604,9 +628,15 @@ class KvCacheAwareRouter(Router):
         tokenizer = self._tokenizers[request.model]
         return [tokenizer(prompt)["input_ids"] for prompt in prompts]

-    async def get_next_server(self, request: OpenAIRequest) -> tuple[str, dict]:
+    async def get_next_server(
+            self,
+            request: OpenAIRequest,
+            exclude_server: Optional[str] = None) -> tuple[str, dict]:
         async with self._lock:
-            servers = list(self._server_state.keys())
+            servers = list([
+                server for server in self._server_state.keys()
+                if server != exclude_server
+            ])
             token_lists = self._tokenize(request)
             block_hashes: list[list[int]] = []
             for token_list in token_lists:
@@ -6,13 +6,14 @@ backend: "pytorch"
 cuda_graph_config: null
 disable_overlap_scheduler: True
 context_servers:
-  num_instances: 1
+  num_instances: 2
   tensor_parallel_size: 1
   pipeline_parallel_size: 1
   cache_transceiver_config:
     backend: DEFAULT
   urls:
     - "localhost:8001"
+    - "localhost:8002"
 generation_servers:
   num_instances: 2
   tensor_parallel_size: 1
@@ -424,3 +424,45 @@ async def test_server_health_check(mock_metadata_server, router_class):
     live_servers = await router.check_servers_health(servers)
     assert len(live_servers) == 1, "Should have one healthy server"
     assert server_url2 in live_servers, "Second server should still be present"
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "router_class", [RoundRobinRouter, LoadBalancingRouter, KvCacheAwareRouter])
+async def test_get_next_server_exclude_server(router_class):
+    servers = ["server1", "server2", "server3"]
+    router = router_class(server_role="context", servers=servers)
+    exclude_server2 = {server: 0 for server in servers}
+    exclude_server3 = {server: 0 for server in servers}
+    for _ in range(0, 10):
+        server, _ = await router.get_next_server(CompletionRequest(
+            model="TinyLlama", prompt=[[10] * 10]),
+                                                 exclude_server="server2")
+        exclude_server2[server] += 1
+        server, _ = await router.get_next_server(CompletionRequest(
+            model="TinyLlama", prompt=[[10] * 10]),
+                                                 exclude_server="server3")
+        exclude_server3[server] += 1
+    if router_class == KvCacheAwareRouter:
+        # KvCacheAwareRouter is not load-balanced
+        assert exclude_server2["server2"] == 0
+        assert exclude_server3["server3"] == 0
+    else:
+        assert exclude_server2["server1"] > 0 and exclude_server2[
+            "server2"] == 0 and exclude_server2["server3"] > 0, exclude_server2
+        assert exclude_server3["server1"] > 0 and exclude_server3[
+            "server2"] > 0 and exclude_server3["server3"] == 0, exclude_server3
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "router_class", [RoundRobinRouter, LoadBalancingRouter, KvCacheAwareRouter])
+async def test_get_next_server_exclude_server_insufficient(router_class):
+    servers = ["server1"]
+    router = router_class(server_role="context",
+                          servers=servers,
+                          use_tokens=False)
+    with pytest.raises(Exception):
+        await router.get_next_server(CompletionRequest(model="TinyLlama",
+                                                       prompt=[[10] * 10]),
+                                     exclude_server=servers[0])