import asyncio
import heapq
from abc import ABC, abstractmethod
from typing import Awaitable, Callable, Dict, Iterable, List, Optional, Union

import aiohttp
from transformers import AutoTokenizer

from tensorrt_llm.bindings.internal.batch_manager import (BlockKey,
                                                          BlockKeyHasher)
from tensorrt_llm.llmapi.disagg_utils import (MetadataServerConfig,
                                              RouterConfig, ServerRole)
from tensorrt_llm.logger import logger
from tensorrt_llm.serve.metadata_server import JsonDictionary
from tensorrt_llm.serve.openai_protocol import (ChatCompletionRequest,
                                                CompletionRequest)

OpenAIRequest = Union[CompletionRequest, ChatCompletionRequest]


def get_request_num_tokens(request: OpenAIRequest) -> int:
    if request.disaggregated_params is None or request.disaggregated_params.request_type == "context_only":
        if isinstance(request, ChatCompletionRequest):
            raise ValueError(
                "LoadBalancing router with tokens doesn't support ChatCompletionRequest yet"
            )

        if isinstance(request.prompt, str) or \
                (isinstance(request.prompt, list) and isinstance(request.prompt[0], int)):
            prompts = [request.prompt]
        else:
            prompts = request.prompt

        # Note: for a plain-string prompt, len() counts characters, which only
        # approximates the token count; tokenized prompts are counted exactly.
        num_tokens = sum(len(prompt) for prompt in prompts)
    elif request.disaggregated_params.request_type == "generation_only":
        raise ValueError(
            "LoadBalancing router with tokens doesn't support generation_only requests"
        )
    else:
        raise ValueError(
            f"Unsupported request type: {request.disaggregated_params.request_type}"
        )

    return num_tokens
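

# A hedged usage sketch (illustrative only; CompletionRequest field names
# follow tensorrt_llm.serve.openai_protocol, and the model name is made up):
#
#   request = CompletionRequest(model="my-model", prompt=[[1, 2, 3], [4, 5]])
#   get_request_num_tokens(request)  # -> 5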


class ServerState:

    def __init__(self, server: str, use_tokens: bool = False):
        self._server = server
        self._num_active_requests = 0
        self._num_active_tokens = 0
        self._use_tokens = use_tokens
        self._lock = asyncio.Lock()

    async def increment_load(self, request: OpenAIRequest):
        num_tokens = get_request_num_tokens(request) if self._use_tokens else 0
        async with self._lock:
            self._num_active_requests += 1
            self._num_active_tokens += num_tokens

    async def decrement_load(self, request: OpenAIRequest):
        num_tokens = get_request_num_tokens(request) if self._use_tokens else 0
        async with self._lock:
            self._num_active_requests -= 1
            self._num_active_tokens -= num_tokens

    async def is_healthy(self) -> bool:
        # ServerState holds no persistent HTTP session, so open a short-lived
        # one for the health probe.
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(self._server + "/health") as response:
                    return response.status == 200
        except Exception:
            return False


class KvCacheAwareServerState(ServerState):

    def __init__(self,
                 server: str,
                 use_tokens: bool = False,
                 tokens_per_block: int = 32):
        super().__init__(server, use_tokens)
        self._kv_cache_block_table: set[int] = set()
        self._tokens_per_block = tokens_per_block

    def add_blocks(self, block_hashes: Iterable[int]):
        for block_hash in block_hashes:
            self._kv_cache_block_table.add(block_hash)

    def remove_blocks(self, block_hashes: Iterable[int]):
        for block_hash in block_hashes:
            self._kv_cache_block_table.discard(block_hash)

    def update_with_events(self, events: Iterable[dict]):
        # event_raw: {"id": <id>, "data": <event body>}
        for event_raw in events:
            if "data" in event_raw:
                event = event_raw["data"]
            else:
                event = event_raw
            if event["type"] == "stored":
                self.add_blocks(block["block_hash"]
                                for block in event["blocks"])
            elif event["type"] == "removed":
                self.remove_blocks(event["block_hashes"])
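
    # Example event payload (shape inferred from the handling above; the hash
    # values are hypothetical):
    #   [{"id": 0, "data": {"type": "stored",
    #                       "blocks": [{"block_hash": 1234}]}},
    #    {"id": 1, "data": {"type": "removed", "block_hashes": [1234]}}]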

    async def poll_events(self, session: aiohttp.ClientSession):
        async with session.post(self._server + "/kv_cache_events") as response:
            events_raw = await response.json()
        return events_raw

    async def matched_tokens(self, block_hashes: list[list[int]]) -> int:
        match_count = 0
        async with self._lock:
            for hash_list in block_hashes:
                for block_hash in hash_list:
                    # TODO: 1) parent hash verification, 2) partial matching
                    if block_hash in self._kv_cache_block_table:
                        match_count += self._tokens_per_block
                    else:
                        break
        return match_count
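
    # Prefix-match sketch: with tokens_per_block=32 and a request whose hash
    # list is [h0, h1, h2], a block table containing {h0, h1} yields 64
    # matched tokens; the scan stops at the first miss (h2).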

    async def decrement_load(self,
                             request: OpenAIRequest,
                             session: Optional[aiohttp.ClientSession] = None):
        num_tokens = get_request_num_tokens(request) if self._use_tokens else 0
        if session is not None:
            events_raw = await self.poll_events(session)
        else:
            events_raw = None
        async with self._lock:
            self._num_active_requests -= 1
            self._num_active_tokens -= num_tokens
            if events_raw is not None:
                self.update_with_events(events_raw)

    def num_active_tokens(self):
        return self._num_active_tokens

    def num_active_requests(self):
        return self._num_active_requests


class Router(ABC):

    def __init__(self,
                 server_role: ServerRole,
                 servers: List[str],
                 metadata_server_cfg: Optional[MetadataServerConfig],
                 metadata_server: Optional[JsonDictionary],
                 server_preparation_func: Optional[Callable[
                     [str], Awaitable[None]]] = None,
                 **kwargs):
        self._servers = servers or []
        self._metadata_server = metadata_server
        self._server_role = server_role
        self._lock = asyncio.Lock()
        self._monitor_task = None
        self._session = None
        self._health_check_timeout = metadata_server_cfg.health_check_timeout if metadata_server_cfg else None
        self._server_preparation_func = server_preparation_func

    @abstractmethod
    def _on_servers_updated(self, old_servers, new_servers):
        """Called when the server list changes. Override in subclasses to
        handle index resets. Called with the lock already held.

        Args:
            old_servers: The previous server list
            new_servers: The new server list
        """

    @property
    def servers(self) -> List[str]:
        return self._servers

    async def _prepare_server(self, server: str):
        if self._server_preparation_func:
            await self._server_preparation_func(server)

    async def prepare_servers(self, servers: Optional[List[str]] = None):
        for server in servers or self._servers:
            await self._prepare_server(server)

    async def add_server(self, server: str):
        if server in self._servers:
            logger.warning(f"Server {server} already exists")
            return
        await self._prepare_server(server)
        async with self._lock:
            old_servers = self._servers.copy()
            self._servers = [*old_servers, server]
            self._on_servers_updated(old_servers, self._servers)
            logger.debug(
                f"Added server {server}, {self._server_role.name} current server list: {self._servers}"
            )

    async def remove_server(self, server: str):
        if server not in self._servers:
            logger.warning(f"Server {server} does not exist")
            return
        async with self._lock:
            old_servers = self._servers.copy()
            self._servers = [
                old_server for old_server in old_servers if old_server != server
            ]
            self._on_servers_updated(old_servers, self._servers)
            logger.debug(
                f"Removed server {server}, current server list: {self._servers}")

    @abstractmethod
    async def get_next_server(self, request: OpenAIRequest) -> tuple[str, dict]:
        '''Select a server for the request and return it together with some
        intermediate routing information.'''

    @abstractmethod
    async def finish_request(self, request: OpenAIRequest):
        pass

    async def start_server_monitoring(self, poll_interval: float = 10.0):
        """Start monitoring server updates from the metadata service"""
        if not self._metadata_server:
            raise RuntimeError("Metadata server is not initialized")

        # Create a session for health checks if it doesn't exist
        if not self._session:
            self._session = aiohttp.ClientSession()

        logger.info(
            f"Starting server monitoring for {self._server_role} servers")
        self._monitor_task = asyncio.create_task(
            self._monitor_servers(poll_interval))

    async def stop_server_monitoring(self):
        """Stop monitoring server updates from the metadata service"""
        if self._monitor_task:
            logger.info(
                f"Stopping server monitoring for {self._server_role} servers")
            self._monitor_task.cancel()
            try:
                await self._monitor_task
            except asyncio.CancelledError:
                pass
            self._monitor_task = None

        # Close the session when stopping monitoring
        await self.close_session()

    async def close_session(self):
        if self._session:
            try:
                await self._session.close()
                self._session = None
                logger.debug("HTTP session closed")
            except Exception as e:
                logger.error(f"Error closing session: {e}")
                self._session = None

    async def _monitor_servers(self, poll_interval: float = 10.0):
        while True:
            try:
                # Get servers from metadata
                server_key_map = await self.fetch_live_servers()

                # Check health and get live servers
                live_servers = await self.check_servers_health(server_key_map)

                # Filter by server role if needed
                role_specific_servers = self._filter_servers_by_role(
                    live_servers, server_key_map)

                # Use filtered servers if available
                final_servers = role_specific_servers
                assert final_servers, f"No {self._server_role} servers available"

                # Update server list
                async with self._lock:
                    if final_servers != self._servers:
                        old_servers = self._servers.copy()
                        self._servers = final_servers

                        # Call handler for server list changes
                        self._on_servers_updated(old_servers, self._servers)

                        # Log removed servers
                        for server in old_servers:
                            if server not in final_servers:
                                logger.info(f"Server {server} is removed")

                        # Log added servers
                        for server in final_servers:
                            if server not in old_servers:
                                await self._prepare_server(server)
                                logger.info(f"Server {server} is added")
                    else:
                        logger.debug(
                            f"No change in {self._server_role} server list: {len(self._servers)} servers"
                        )
            except Exception as e:
                logger.error(f"Error in server monitoring: {e}")
                raise

            # Wait before next poll
            await asyncio.sleep(poll_interval)

    def _filter_servers_by_role(self, servers, server_key_map):
        """Filter servers by role (context or generation)"""
        if not servers:
            raise RuntimeError("No servers available")

        filtered_servers = []
        # Invert to get {url: key} for lookup
        url_to_key = {url: key for key, url in server_key_map.items()}
        for server_url in servers:
            key = url_to_key.get(server_url)
            if key:
                server_metadata = self._metadata_server.get(key)
                if server_metadata:
                    server_type = self._get_server_type(server_metadata)
                    if self._is_matching_role(server_type):
                        filtered_servers.append(server_url)
        return filtered_servers

    def _get_server_type(self, server_metadata: dict) -> str:
        return (server_metadata.get('server_type')
                or server_metadata.get('server_role') or '').lower()

    def _is_matching_role(self, server_type: str) -> bool:
        return (self._server_role == ServerRole.CONTEXT and server_type == 'context') or \
               (self._server_role == ServerRole.GENERATION and server_type == 'generation')

    async def fetch_live_servers(self) -> Dict[str, str]:
        """Fetch all servers from the metadata service and return a {key: url} mapping"""
        if not self._metadata_server:
            raise RuntimeError("Metadata server is not initialized")

        # If a metadata server is available, ignore the static server list entirely
        server_key_map = {}
        try:
            # Get all keys from the metadata server
            all_keys = self._metadata_server.keys()
            logger.debug(f"Found {len(all_keys)} keys in metadata server")

            # Filter keys that start with 'trtllm/' and extract server metadata
            for key in all_keys:
                if key.startswith('trtllm/'):
                    server_metadata = self._metadata_server.get(key)
                    if server_metadata and isinstance(
                            server_metadata, dict) and 'url' in server_metadata:
                        server_key_map[key] = server_metadata['url']

            if server_key_map:
                logger.debug(
                    f"Using {len(server_key_map)} servers from metadata service"
                )
            else:
                raise ValueError("No servers found in metadata service")
        except Exception as e:
            logger.error(f"Error fetching servers from metadata service: {e}")
            raise

        return server_key_map

    async def check_servers_health(self,
                                   server_key_map: Dict[str, str]) -> List[str]:
        """Check the health of servers and remove dead ones from the metadata service"""
        live_servers = []
        dead_servers = []

        # Check the health of each server
        for key, server_url in server_key_map.items():
            try:
                is_healthy = await self._check_server_health(server_url)
                # If the first attempt failed, retry once before declaring the
                # server dead
                if not is_healthy:
                    is_healthy = await self._check_server_health(server_url)
                if is_healthy:
                    live_servers.append(server_url)
                else:
                    dead_servers.append((key, server_url))
            except Exception as e:
                logger.error(
                    f"Error checking health for server {server_url} (key: {key}): {e}"
                )
                dead_servers.append((key, server_url))

        # Remove dead servers from the metadata service
        for key, dead_server in dead_servers:
            try:
                logger.info(
                    f"Removing dead server {dead_server} from metadata server")
                self._metadata_server.remove(key)
            except Exception as e:
                logger.error(
                    f"Error removing dead server from metadata service: {e}")
                raise

        return live_servers

    async def _check_server_health(self, server_url) -> bool:
        """Check if a server is healthy by querying its health endpoint"""
        if not self._session:
            self._session = aiohttp.ClientSession()

        assert self._health_check_timeout is not None, "health_check_timeout is not set"
        try:
            async with self._session.get(
                    f"{server_url}/health",
                    timeout=self._health_check_timeout) as response:
                if response.status != 200:
                    logger.warning(
                        f"Server {server_url} is not healthy (status: {response.status})"
                    )
                    return False
                return True
        except Exception as e:
            logger.warning(f"Server {server_url} is not reachable: {e}")
            return False


class RoundRobinRouter(Router):

    def __init__(self,
                 server_role: ServerRole,
                 servers: Optional[List[str]] = None,
                 metadata_server_cfg: Optional[MetadataServerConfig] = None,
                 metadata_server: Optional[JsonDictionary] = None,
                 **kwargs):
        super().__init__(server_role, servers, metadata_server_cfg,
                         metadata_server, **kwargs)
        self._server_idx = 0

    def _on_servers_updated(self, old_servers, new_servers):
        """Reset the index when servers are removed to prevent index out of
        bounds errors."""
        if len(new_servers) < len(old_servers):
            # Servers were removed, reset the index
            self._server_idx = 0
        elif self._server_idx >= len(new_servers):
            # Safety check: ensure the index is always within bounds
            self._server_idx = 0

    async def get_next_server(self, request: OpenAIRequest) -> tuple[str, dict]:
        if not self._servers:
            if self._metadata_server:
                raise ValueError(
                    f"No {self._server_role} servers available in metadata service"
                )
            else:
                raise ValueError(f"No {self._server_role} servers available")

        async with self._lock:
            # Safety check: ensure the index is within bounds
            if self._server_idx >= len(self._servers):
                self._server_idx = 0
            server = self._servers[self._server_idx]
            self._server_idx = (self._server_idx + 1) % len(self._servers)
        return server, {}

    async def finish_request(self, request: OpenAIRequest):
        pass
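

# A minimal routing lifecycle sketch (assumes an async caller and an existing
# OpenAIRequest object; the URLs are placeholders):
#
#   router = RoundRobinRouter(ServerRole.CONTEXT,
#                             ["http://gpu0:8000", "http://gpu1:8000"])
#   server, _ = await router.get_next_server(request)
#   ...  # proxy the request to `server`
#   await router.finish_request(request)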


class LoadBalancingRouter(Router):

    def __init__(self,
                 server_role: ServerRole,
                 servers: Optional[List[str]] = None,
                 metadata_server_cfg: Optional[MetadataServerConfig] = None,
                 metadata_server: Optional[JsonDictionary] = None,
                 use_tokens: bool = False,
                 **kwargs):
        super().__init__(server_role, servers, metadata_server_cfg,
                         metadata_server, **kwargs)
        # Load map between servers and their number of tokens processed
        self._server_state = {}
        self._server_load_heap = []

        # Routing table to map requests to servers
        self._req_routing_table = {}

        self._use_tokens = use_tokens
        self._init_heap()

    def _on_servers_updated(self, old_servers, new_servers):
        """Rebuild the heap when the server list changes."""
        # Keep the state for servers that still exist
        current_state = {}
        for server in new_servers:
            if server in self._server_state:
                # Keep existing state
                current_state[server] = self._server_state[server]
            else:
                # Initialize new server state
                current_state[server] = ServerState(server, self._use_tokens)

        # Update state and rebuild the heap
        self._server_state = current_state
        self._server_load_heap = []
        for server in new_servers:
            heapq.heappush(self._server_load_heap,
                           (self._get_server_load(server), server))

    def _init_heap(self):
        for server in self._servers:
            self._server_state[server] = ServerState(server, self._use_tokens)
            heapq.heappush(self._server_load_heap,
                           (self._get_server_load(server), server))

    async def get_next_server(self, request: OpenAIRequest) -> tuple[str, dict]:
        if not self._servers:
            if self._metadata_server:
                raise ValueError(
                    f"No {self._server_role} servers available in metadata service"
                )
            else:
                raise ValueError(f"No {self._server_role} servers available")

        async with self._lock:
            server = heapq.heappop(self._server_load_heap)[1]
            await self._server_state[server].increment_load(request)
            heapq.heappush(self._server_load_heap,
                           (self._get_server_load(server), server))
            self._req_routing_table[id(request)] = server
        return server, {}

    def _get_server_load(self, server):
        return self._server_state[server]._num_active_tokens if self._use_tokens \
            else self._server_state[server]._num_active_requests

    async def finish_request(self, request: OpenAIRequest):
        async with self._lock:
            server = self._req_routing_table[id(request)]
            await self._server_state[server].decrement_load(request)
            heapq.heappush(self._server_load_heap,
                           (self._get_server_load(server), server))
            del self._req_routing_table[id(request)]
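

# Note (inferred from the code above, not a guarantee): finish_request pushes
# a fresh (load, server) entry without removing that server's older entries,
# so the heap can hold stale duplicates. Because get_next_server always
# reinserts the popped server with its current load, the minimum stays
# approximately correct; popping an occasional stale entry skews routing
# slightly rather than breaking it.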


def block_key_hasher(token_ids: list[int],
                     parent_hash: Optional[int] = None) -> int:
    block_key = BlockKey(token_ids)
    return BlockKeyHasher.hash(block_key,
                               0 if parent_hash is None else parent_hash)
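

# A hedged sketch of the hash chaining used by KvCacheAwareRouter below: each
# block's hash folds in its parent's hash, so matching a block in a server's
# table implies the entire token prefix up to that block matched.
#
#   h0 = block_key_hasher(tokens[0:32])       # first block has no parent
#   h1 = block_key_hasher(tokens[32:64], h0)  # second block chains off h0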


class KvCacheAwareRouter(Router):

    def __init__(self,
                 server_role: Optional[ServerRole] = None,
                 servers: Optional[List[str]] = None,
                 metadata_server_cfg: Optional[MetadataServerConfig] = None,
                 metadata_server: Optional[JsonDictionary] = None,
                 use_tokens: bool = False,
                 max_batch_size: int = 64,
                 tokens_per_block: int = 32,
                 **kwargs):
        super().__init__(server_role, servers, metadata_server_cfg,
                         metadata_server, **kwargs)
        self._lock = asyncio.Lock()
        self._use_tokens = use_tokens

        # Load map between servers and their number of tokens processed
        self._server_state: dict[str, KvCacheAwareServerState] = {
            server: KvCacheAwareServerState(server, use_tokens)
            for server in servers or []
        }

        # Routing table to map requests to servers
        self._req_routing_table: dict[int, OpenAIRequest] = {}

        self._tokenizers = {}
        # TODO: use max_num_tokens? per server?
        self._max_batch_size = max_batch_size
        self._tokens_per_block = tokens_per_block

    def _tokenize(self, request: OpenAIRequest) -> list[list[int]]:
        prompts = request.prompt
        if isinstance(prompts, list) and isinstance(prompts[0], list):
            return prompts
        elif isinstance(prompts, list) and isinstance(prompts[0], int):
            return [prompts]
        elif isinstance(prompts, str):
            prompts = [prompts]
        else:
            assert isinstance(prompts, list) and isinstance(prompts[0], str)

        # TODO: send a tokenize-only request instead of tokenizing locally
        if request.model not in self._tokenizers:
            self._tokenizers[request.model] = AutoTokenizer.from_pretrained(
                request.model)
        tokenizer = self._tokenizers[request.model]
        return [tokenizer(prompt)["input_ids"] for prompt in prompts]
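
    # _tokenize normalizes the four prompt shapes CompletionRequest allows
    # (the token IDs shown are hypothetical):
    #   "hi"          -> tokenized         -> [[1001]]
    #   ["hi", "yo"]  -> tokenized each    -> [[1001], [2002]]
    #   [1, 2, 3]     -> already token IDs -> [[1, 2, 3]]
    #   [[1, 2], [3]] -> already token IDs -> [[1, 2], [3]]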

    async def get_next_server(self, request: OpenAIRequest) -> tuple[str, dict]:
        async with self._lock:
            servers = list(self._server_state.keys())

        token_lists = self._tokenize(request)
        block_hashes: list[list[int]] = []
        for token_list in token_lists:
            hash_list = []
            # in KvCacheManager, the last token is not included in the block key
            for t in range(0, len(token_list) - 1, self._tokens_per_block):
                t_end = min(t + self._tokens_per_block, len(token_list) - 1)
                hash_list.append(
                    block_key_hasher(token_list[t:t_end],
                                     None if t == 0 else hash_list[-1]))
            block_hashes.append(hash_list)
        # max(1, ...) guards against prompts too short to produce any block
        padded_tokens = max(
            1,
            sum(len(hash_list)
                for hash_list in block_hashes) * self._tokens_per_block)

        # select the server by (KV match - load)
        # TODO: more options
        workloads = [
            state.num_active_requests()
            for state in self._server_state.values()
        ]
        scores = []
        matches = []
        for i in range(len(servers)):
            server = servers[i]
            # https://github.com/ai-dynamo/dynamo/blob/main/docs/kv_cache_routing.md#kv-cache-routing-and-load-balancing
            matches.append(
                await self._server_state[server].matched_tokens(block_hashes))
            score = matches[-1] / padded_tokens - workloads[
                i] / self._max_batch_size
            scores.append(score)
        server = servers[scores.index(max(scores))]

        async with self._lock:
            await self._server_state[server].increment_load(request)
            self._req_routing_table[id(request)] = server
        return server, {
            "block_hashes": block_hashes,  # list[list[int]]
            "token_lists": token_lists,  # list[list[int]]
            "matches": matches,  # list[int]
        }
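
    # Scoring sketch for get_next_server with hypothetical numbers
    # (tokens_per_block=32, max_batch_size=64, the request hashes into 4
    # blocks, so padded_tokens = 128):
    #   server A: 96 matched tokens, 8 active requests -> 96/128 - 8/64 = 0.625
    #   server B: 32 matched tokens, 2 active requests -> 32/128 - 2/64 = 0.21875
    # Server A wins: its KV cache reuse outweighs its higher load.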

    async def finish_request(self,
                             request: OpenAIRequest,
                             session: Optional[aiohttp.ClientSession] = None):
        async with self._lock:
            server = self._req_routing_table[id(request)]
            del self._req_routing_table[id(request)]
        # decrement_load may poll KV cache events over HTTP, so it runs
        # outside the router lock (the server state has its own lock)
        if server in self._server_state:
            await self._server_state[server].decrement_load(request,
                                                            session=session)

    def _on_servers_updated(self, old_servers, new_servers):
        # Keep existing state (including the KV cache block table) for servers
        # that survive the update; only add genuinely new servers and drop
        # servers that are gone.
        for new_server in new_servers:
            if new_server not in self._server_state:
                self._server_state[new_server] = KvCacheAwareServerState(
                    new_server, self._use_tokens)
        for old_server in old_servers:
            if old_server not in new_servers:
                self._server_state.pop(old_server, None)


def create_router(
    router_config: Optional[RouterConfig],
    servers: Optional[List[str]],
    metadata_server_cfg: Optional[MetadataServerConfig] = None,
    metadata_server: Optional[JsonDictionary] = None,
    server_preparation_func: Optional[Callable[[str], Awaitable[None]]] = None
) -> Router:
    """
    Factory function to create different types of router instances.

    Args:
        router_config: Router configuration; its `type` field selects the
            router. Supported values:
            - "round_robin": Creates a RoundRobinRouter (default when no
              config is given)
            - "load_balancing": Creates a LoadBalancingRouter, which balances
              requests or tokens across instances
            - "kv_cache_aware": Creates a KvCacheAwareRouter, which balances
              requests across instances additionally based on KV cache hits
        servers: List of server URLs
        metadata_server_cfg: Optional metadata server configuration
        metadata_server: Optional metadata server handle
        server_preparation_func: Optional async callback run for each server
            before it is added

    Returns:
        Router: An instance of the requested router type

    Raises:
        ValueError: If an unsupported router type is provided
    """
    router_map = {
        "round_robin": RoundRobinRouter,
        "load_balancing": LoadBalancingRouter,
        "kv_cache_aware": KvCacheAwareRouter,
    }

    router_type = router_config.type if router_config else "round_robin"
    router_class = router_map.get(router_type.lower())
    if router_class is None:
        raise ValueError(f"Unsupported router type: {router_type}. "
                         f"Supported types are: {list(router_map.keys())}")

    extra_args = router_config.args if router_config else {}
    return router_class(router_config.server_role if router_config else None,
                        servers,
                        metadata_server_cfg,
                        metadata_server,
                        server_preparation_func=server_preparation_func,
                        **extra_args)
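

# A hedged end-to-end sketch (assumes a RouterConfig with type="round_robin"
# and server_role=ServerRole.CONTEXT; RouterConfig field names may differ):
#
#   config = RouterConfig(type="round_robin", server_role=ServerRole.CONTEXT)
#   router = create_router(config, ["http://gpu0:8000", "http://gpu1:8000"])
#   server, info = await router.get_next_server(request)
#   ...  # proxy the request to `server`
#   await router.finish_request(request)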