mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-27 14:13:34 +08:00
* Add draft scheduler class Signed-off-by: Shunkang <182541032+Shunkangz@users.noreply.github.co> * Refactor the design Signed-off-by: Shunkang <182541032+Shunkangz@users.noreply.github.co> * feat: Introduce router class for disaggregated server Signed-off-by: Patrice Castonguay <55748270+pcastonguay@users.noreply.github.com> * Add unit tests for router class Signed-off-by: Patrice Castonguay <55748270+pcastonguay@users.noreply.github.com> * Adding tests for disagg_utils Signed-off-by: Patrice Castonguay <55748270+pcastonguay@users.noreply.github.com> * Fixing missing import Signed-off-by: Patrice Castonguay <55748270+pcastonguay@users.noreply.github.com> * Fixing disagg integration tests Signed-off-by: Patrice Castonguay <55748270+pcastonguay@users.noreply.github.com> * Addressing MR review comments Signed-off-by: Patrice Castonguay <55748270+pcastonguay@users.noreply.github.com> --------- Signed-off-by: Shunkang <182541032+Shunkangz@users.noreply.github.co> Signed-off-by: Patrice Castonguay <55748270+pcastonguay@users.noreply.github.com> Co-authored-by: Shunkang <182541032+Shunkangz@users.noreply.github.co>
178 lines
6.4 KiB
Python
178 lines
6.4 KiB
Python
import asyncio
|
|
import heapq
|
|
from abc import ABC, abstractmethod
|
|
from typing import List, Union
|
|
|
|
from tensorrt_llm.serve.openai_protocol import (ChatCompletionRequest,
|
|
CompletionRequest)
|
|
|
|
|
|
def get_request_num_tokens(
        request: Union[CompletionRequest, ChatCompletionRequest]) -> int:
    """Return the token load a request contributes to the server it runs on.

    Only context-only completion requests are supported. The result is the sum
    of the lengths of every prompt in the request: for token-id prompts that is
    the number of ids, for string prompts it is the number of characters (no
    tokenizer is applied here, so string lengths are only a proxy for tokens).

    Args:
        request: The request being routed. ``request.disaggregated_params``
            must be set and its ``request_type`` must be ``"context_only"``.

    Returns:
        The summed prompt length across all prompts in the request
        (0 for an empty prompt list).

    Raises:
        ValueError: If ``disaggregated_params`` is missing, the request type is
            ``"generation_only"`` or any other non-``"context_only"`` value, or
            the request is a ``ChatCompletionRequest``.
    """
    params = request.disaggregated_params
    if params is None:
        raise ValueError(
            "LoadBalancing router with tokens requires disaggregated_params")

    request_type = params.request_type
    if request_type == "generation_only":
        raise ValueError(
            "LoadBalancing router with tokens doesn't support generation_only requests"
        )
    if request_type != "context_only":
        raise ValueError(f"Unsupported request type: {request_type}")
    if isinstance(request, ChatCompletionRequest):
        raise ValueError(
            "LoadBalancing router with tokens doesn't support ChatCompletionRequest yet"
        )

    prompt = request.prompt
    # A single prompt is either a string or a flat list of token ids; anything
    # else is treated as a batch of prompts. The truthiness check keeps an
    # empty list from being indexed; it then counts as an empty batch (0).
    is_single_prompt = isinstance(prompt, str) or (isinstance(prompt, list)
                                                   and prompt
                                                   and isinstance(prompt[0], int))
    prompts = [prompt] if is_single_prompt else prompt
    return sum(len(p) for p in prompts)
|
|
|
|
|
|
class ServerState:
|
|
|
|
def __init__(self, server: str, use_tokens: bool = False):
|
|
self._server = server
|
|
self._num_active_requests = 0
|
|
self._num_active_tokens = 0
|
|
self._use_tokens = use_tokens
|
|
self._lock = asyncio.Lock()
|
|
|
|
async def increment_load(self, request: Union[CompletionRequest,
|
|
ChatCompletionRequest]):
|
|
num_tokens = get_request_num_tokens(request) if self._use_tokens else 0
|
|
async with self._lock:
|
|
self._num_active_requests += 1
|
|
self._num_active_tokens += num_tokens
|
|
|
|
async def decrement_load(self, request: Union[CompletionRequest,
|
|
ChatCompletionRequest]):
|
|
num_tokens = get_request_num_tokens(request) if self._use_tokens else 0
|
|
async with self._lock:
|
|
self._num_active_requests -= 1
|
|
self._num_active_tokens -= num_tokens
|
|
|
|
|
|
class Router(ABC):
    """Abstract interface for choosing which server handles a request.

    Subclasses implement ``get_next_server`` to pick a server for a request and
    ``finish_request`` to be notified once that request has completed.
    """

    def __init__(self, servers: List[str] = None):
        # Kept exactly as passed in (no copy, no validation); subclasses read it.
        self._servers = servers

    @abstractmethod
    async def get_next_server(
            self, request: Union[CompletionRequest,
                                 ChatCompletionRequest]) -> str:
        """Return the server that should handle ``request``."""

    @abstractmethod
    async def finish_request(self, request: Union[CompletionRequest,
                                                  ChatCompletionRequest]):
        """Notify the router that ``request`` has finished."""
|
|
|
|
|
|
class RoundRobinRouter(Router):
    """Router that hands out servers in list order, wrapping around at the end.

    The choice never depends on load, so ``finish_request`` does nothing.
    """

    def __init__(self, servers: List[str] = None):
        super().__init__(servers)
        # Position in self._servers of the server the next call returns.
        self._server_idx = 0

    async def get_next_server(
            self, request: Union[CompletionRequest,
                                 ChatCompletionRequest]) -> str:
        """Return the server at the cursor and advance the cursor cyclically."""
        pool = self._servers
        position = self._server_idx
        # Look up before advancing, so the cursor is untouched if this fails.
        selected = pool[position]
        self._server_idx = (position + 1) % len(pool)
        return selected

    async def finish_request(self, request: Union[CompletionRequest,
                                                  ChatCompletionRequest]):
        """No bookkeeping is needed for round-robin routing."""
|
|
|
|
|
|
class LoadBalancingRouter(Router):
    """Router that sends each request to the currently least-loaded server.

    Load is the number of in-flight requests, or, when ``use_tokens`` is true,
    the summed ``get_request_num_tokens`` of in-flight requests. Heap entries
    are ``(load, server)`` tuples, so ties are broken by comparing the server
    strings.
    """

    def __init__(self, servers: List[str] = None, use_tokens: bool = False):
        super().__init__(servers)
        # Serializes every access to the heap, server states and routing table.
        self._lock = asyncio.Lock()
        # Per-server load counters, keyed by server.
        self._server_state = {}
        # Min-heap of (load, server) holding exactly one up-to-date entry per
        # server.
        self._server_load_heap = []

        # Routing table to map requests (keyed by id(request)) to servers
        self._req_routing_table = {}

        self._use_tokens = use_tokens
        self._init_heap()

    def _init_heap(self):
        """Create a ServerState for each server and build the load heap."""
        for server in self._servers:
            self._server_state[server] = ServerState(server, self._use_tokens)
        self._rebuild_heap()

    def _rebuild_heap(self):
        """Recreate the heap from the current load of every known server."""
        self._server_load_heap = [(self._get_server_load(server), server)
                                  for server in self._server_state]
        heapq.heapify(self._server_load_heap)

    async def get_next_server(
            self, request: Union[CompletionRequest,
                                 ChatCompletionRequest]) -> str:
        """Assign ``request`` to the least-loaded server and return it.

        Raises:
            IndexError: If there are no servers.
            ValueError: Propagated from ``get_request_num_tokens`` in token
                mode; the router state is left unchanged in that case.
        """
        async with self._lock:
            # Peek instead of popping: if increment_load raises, the server
            # must not disappear from the heap.
            server = self._server_load_heap[0][1]
            await self._server_state[server].increment_load(request)
            # The root is still this server's entry (all heap mutations happen
            # under self._lock), so swap it for the updated load.
            heapq.heapreplace(self._server_load_heap,
                              (self._get_server_load(server), server))

            self._req_routing_table[id(request)] = server

        return server

    def _get_server_load(self, server):
        """Current load of ``server``: tokens or request count, per mode."""
        return self._server_state[server]._num_active_tokens if self._use_tokens \
            else self._server_state[server]._num_active_requests

    async def finish_request(self, request: Union[CompletionRequest,
                                                  ChatCompletionRequest]):
        """Release the load that ``request`` placed on its server.

        Raises:
            KeyError: If ``request`` was not routed by this router (or was
                already finished).
        """
        async with self._lock:
            server = self._req_routing_table[id(request)]
            await self._server_state[server].decrement_load(request)
            # Rebuild rather than push: pushing would leave the server's old,
            # higher-load entry in the heap, which is later popped as if it
            # were current and makes the heap grow with every finished request.
            self._rebuild_heap()
            del self._req_routing_table[id(request)]
|
|
|
|
|
|
def create_router(router_type: str, servers: List[str]) -> Router:
    """
    Factory function to create different types of router instances.

    Args:
        router_type (str): Type of router to create, matched case-insensitively.
            Supported values:
            - "round_robin": Creates a RoundRobinRouter
            - "requests_load_balancing": Creates a LoadBalancingRouter, which balances requests across instances
            - "tokens_load_balancing": Creates a LoadBalancingRouter, which balances tokens across instances
        servers: List of server URLs

    Returns:
        Router: An instance of the requested router type

    Raises:
        ValueError: If an unsupported router type is provided
    """

    router_map = {
        "round_robin": RoundRobinRouter,
        "requests_load_balancing": LoadBalancingRouter,
        "tokens_load_balancing": LoadBalancingRouter
    }

    # Normalize once so the lookup and the token-mode decision agree for
    # mixed-case input such as "TOKENS_LOAD_BALANCING".
    normalized_type = router_type.lower()
    router_class = router_map.get(normalized_type)
    if router_class is None:
        raise ValueError(f"Unsupported router type: {router_type}. "
                         f"Supported types are: {list(router_map.keys())}")

    if normalized_type.endswith("load_balancing"):
        use_tokens = normalized_type.startswith("tokens")
        return router_class(servers, use_tokens=use_tokens)
    return router_class(servers)
|