import asyncio import heapq from abc import ABC, abstractmethod from typing import Dict, Iterable, List, Optional, Union import aiohttp from transformers import AutoTokenizer from tensorrt_llm.bindings.internal.batch_manager import (BlockKey, BlockKeyHasher) from tensorrt_llm.llmapi.disagg_utils import (MetadataServerConfig, RouterConfig, ServerRole) from tensorrt_llm.logger import logger from tensorrt_llm.serve.metadata_server import JsonDictionary from tensorrt_llm.serve.openai_protocol import (ChatCompletionRequest, CompletionRequest) OpenAIRequest = Union[CompletionRequest, ChatCompletionRequest] def get_request_num_tokens(request: OpenAIRequest) -> int: if request.disaggregated_params is None or request.disaggregated_params.request_type == "context_only": if isinstance(request, ChatCompletionRequest): raise ValueError( "LoadBalancing router with tokens doesn't support ChatCompletionRequest yet" ) if isinstance(request.prompt, str) or \ (isinstance(request.prompt, list) and isinstance(request.prompt[0], int)): prompts = [request.prompt] else: prompts = request.prompt num_tokens = sum(len(prompt) for prompt in prompts) elif request.disaggregated_params.request_type == "generation_only": raise ValueError( "LoadBalancing router with tokens doesn't support generation_only requests" ) else: raise ValueError( f"Unsupported request type: {request.disaggregated_params.request_type}" ) return num_tokens class ServerState: def __init__(self, server: str, use_tokens: bool = False): self._server = server self._num_active_requests = 0 self._num_active_tokens = 0 self._use_tokens = use_tokens self._lock = asyncio.Lock() async def increment_load(self, request: OpenAIRequest): num_tokens = get_request_num_tokens(request) if self._use_tokens else 0 async with self._lock: self._num_active_requests += 1 self._num_active_tokens += num_tokens async def decrement_load(self, request: OpenAIRequest): num_tokens = get_request_num_tokens(request) if self._use_tokens else 0 async with self._lock: self._num_active_requests -= 1 self._num_active_tokens -= num_tokens async def is_healthy(self) -> bool: try: async with self._session.get(self._server + "/health") as response: return response.status == 200 except Exception: return False class KvCacheAwareServerState(ServerState): def __init__(self, server: str, use_tokens: bool = False, tokens_per_block: int = 32): super().__init__(server, use_tokens) self._kv_cache_block_table: set[int] = set() self._tokens_per_block = tokens_per_block def add_blocks(self, block_hashes: Iterable[int]): for hash in block_hashes: self._kv_cache_block_table.add(hash) def remove_blocks(self, block_hashes: Iterable[int]): for hash in block_hashes: self._kv_cache_block_table.discard(hash) def update_with_events(self, events: Iterable[dict]): # event_raw: {"id": , "data": } for event_raw in events: if "data" in event_raw: event = event_raw["data"] else: event = event_raw if event["type"] == "stored": self.add_blocks(block["block_hash"] for block in event["blocks"]) elif event["type"] == "removed": self.remove_blocks(event["block_hashes"]) async def poll_events(self, session: aiohttp.ClientSession): async with session.post(self._server + "/kv_cache_events") as response: events_raw = await response.json() return events_raw async def matched_tokens(self, block_hashes: list[list[int]]) -> int: match_count = 0 async with self._lock: for hash_list in block_hashes: for hash in hash_list: # TODO: 1) parent hash verification, 2) partial matching if hash in self._kv_cache_block_table: match_count += self._tokens_per_block else: break return match_count async def decrement_load(self, request: OpenAIRequest, session: Optional[aiohttp.ClientSession] = None): num_tokens = get_request_num_tokens(request) if self._use_tokens else 0 if session is not None: events_raw = await self.poll_events(session) else: events_raw = None async with self._lock: self._num_active_requests -= 1 self._num_active_tokens -= num_tokens if events_raw is not None: self.update_with_events(events_raw) def num_active_tokens(self): return self._num_active_tokens def num_active_requests(self): return self._num_active_requests class Router(ABC): def __init__(self, server_role: ServerRole, servers: List[str], metadata_server_cfg: Optional[MetadataServerConfig], metadata_server: Optional[JsonDictionary]): self._servers = servers or [] self._metadata_server = metadata_server self._server_role = server_role self._lock = asyncio.Lock() self._monitor_task = None self._session = None self._health_check_timeout = metadata_server_cfg.health_check_timeout if metadata_server_cfg else None @abstractmethod def _on_servers_updated(self, old_servers, new_servers): """Called when the server list changes. Override in subclasses to handle index resets. Args: old_servers: The previous server list new_servers: The new server list """ @property def servers(self) -> List[str]: return self._servers async def add_server(self, server: str): if server in self._servers: logger.warning(f"Server {server} already exists") return async with self._lock: old_servers = self._servers.copy() self._servers = [*old_servers, server] self._on_servers_updated(old_servers, self._servers) logger.debug( f"Added server {server}, current server list: {self._servers}") async def remove_server(self, server: str): if server not in self._servers: logger.warning(f"Server {server} does not exist") return async with self._lock: old_servers = self._servers.copy() self._servers = [ old_server for old_server in old_servers if old_server != server ] self._on_servers_updated(old_servers, self._servers) logger.debug( f"Removed server {server}, current server list: {self._servers}") @abstractmethod async def get_next_server(self, request: OpenAIRequest) -> tuple[str, dict]: '''Select server by request and return some intermediate information''' @abstractmethod async def finish_request(self, request: OpenAIRequest): pass async def start_server_monitoring(self, poll_interval: float = 10.0): """Start monitoring servers update from metadata service""" if not self._metadata_server: raise RuntimeError("Metadata server is not initialized") # Create a session for health checks if it doesn't exist if not self._session: self._session = aiohttp.ClientSession() logger.info( f"Starting server monitoring for {self._server_role} servers") self._monitor_task = asyncio.create_task( self._monitor_servers(poll_interval)) async def stop_server_monitoring(self): """Stop monitoring servers update from metadata service""" if self._monitor_task: logger.info( f"Stopping server monitoring for {self._server_role} servers") self._monitor_task.cancel() try: await self._monitor_task except asyncio.CancelledError: pass self._monitor_task = None # Close session when stopping monitoring await self.close_session() async def close_session(self): if self._session: try: await self._session.close() self._session = None logger.debug("HTTP session closed") except Exception as e: logger.error(f"Error closing session: {e}") self._session = None async def _monitor_servers(self, poll_interval: float = 10.0): while True: try: # Get servers from metadata server_key_map = await self.fetch_live_servers() # Check health and get live servers live_servers = await self.check_servers_health(server_key_map) # Filter by server role if needed role_specific_servers = self._filter_servers_by_role( live_servers, server_key_map) # Use filtered servers if available final_servers = role_specific_servers assert final_servers, f"No {self._server_role} servers available" # Update server list async with self._lock: if final_servers != self._servers: old_servers = self._servers.copy() self._servers = final_servers # Call handler for server list changes self._on_servers_updated(old_servers, self._servers) # Log removed servers for server in old_servers: if server not in final_servers: logger.info(f"Server {server} is removed") # Log added servers for server in final_servers: if server not in old_servers: logger.info(f"Server {server} is added") else: logger.debug( f"No change in {self._server_role} server list: {len(self._servers)} servers" ) except Exception as e: logger.error(f"Error in server monitoring: {e}") raise # Wait before next poll await asyncio.sleep(poll_interval) def _filter_servers_by_role(self, servers, server_key_map): """Filter servers by role (context or generation)""" if not servers: raise RuntimeError("No servers available") filtered_servers = [] # Invert to get {url: key} for lookup url_to_key = {url: key for key, url in server_key_map.items()} for server_url in servers: key = url_to_key.get(server_url) if key: server_metadata = self._metadata_server.get(key) if server_metadata: server_type = self._get_server_type(server_metadata) if self._is_matching_role(server_type): filtered_servers.append(server_url) return filtered_servers def _get_server_type(self, server_metadata: dict) -> str: return (server_metadata.get('server_type') or server_metadata.get('server_role') or '').lower() def _is_matching_role(self, server_type: str) -> bool: return (self._server_role == ServerRole.CONTEXT and server_type == 'context') or \ (self._server_role == ServerRole.GENERATION and server_type == 'generation') async def fetch_live_servers(self) -> Dict[str, str]: """Fetch all servers from metadata service and return {key: url} mapping""" if not self._metadata_server: raise RuntimeError("Metadata server is not initialized") # If metadata server is available, ignore static server list entirely server_key_map = {} try: # Get all keys from the metadata server all_keys = self._metadata_server.keys() logger.debug(f"Found {len(all_keys)} keys in metadata server") # Filter keys that start with 'trtllm/' and extract server metadata for key in all_keys: if key.startswith('trtllm/'): server_metadata = self._metadata_server.get(key) if server_metadata and isinstance( server_metadata, dict) and 'url' in server_metadata: server_key_map[key] = server_metadata['url'] if server_key_map: logger.debug( f"Using {len(server_key_map)} servers from metadata service" ) else: raise ValueError("No servers found in metadata service") except Exception as e: logger.error(f"Error fetching servers from metadata service: {e}") raise return server_key_map async def check_servers_health(self, server_key_map: Dict[str, str]) -> List[str]: """Check health of servers and remove dead ones from metadata service""" live_servers = [] dead_servers = [] # Check health of each server for key, server_url in server_key_map.items(): try: is_healthy = await self._check_server_health(server_url) # If first attempt failed, try again before declaring server dead if not is_healthy: # Second attempt - will print errors if it fails is_healthy = await self._check_server_health(server_url) if not is_healthy: # Only now add to dead servers dead_servers.append((key, server_url)) else: live_servers.append(server_url) except Exception as e: logger.error( f"Error checking health for server {server_url} (key: {key}): {e}" ) dead_servers.append((key, server_url)) # Remove dead servers from etcd for key, dead_server in dead_servers: try: logger.info( f"Removing dead server {dead_server} from metadata server") self._metadata_server.remove(key) except Exception as e: logger.error( f"Error removing dead server from metadata service: {e}") raise return live_servers async def _check_server_health(self, server_url) -> bool: """Check if a server is healthy by querying its health endpoint""" if not self._session: self._session = aiohttp.ClientSession() assert self._health_check_timeout is not None, "health_check_timeout is not set" try: async with self._session.get( f"{server_url}/health", timeout=self._health_check_timeout) as response: if response.status != 200: logger.warning( f"Server {server_url} is not healthy (status: {response.status})" ) return False return True except Exception as e: logger.warning(f"Server {server_url} is not reachable: {e}") return False class RoundRobinRouter(Router): def __init__(self, server_role: ServerRole, servers: List[str] = None, metadata_server_cfg: MetadataServerConfig = None, metadata_server: JsonDictionary = None, **kwargs): super().__init__(server_role, servers, metadata_server_cfg, metadata_server) self._server_idx = 0 def _on_servers_updated(self, old_servers, new_servers): """Reset the index when servers are removed to prevent index out of bounds errors.""" if len(new_servers) < len(old_servers): # Servers were removed, reset the index self._server_idx = 0 elif self._server_idx >= len(new_servers): # Safety check: ensure index is always within bounds self._server_idx = 0 async def get_next_server(self, request: OpenAIRequest) -> tuple[str, dict]: if not self._servers: if self._metadata_server: raise ValueError( f"No {self._server_role} servers available in metadata service" ) else: raise ValueError(f"No {self._server_role} servers available") async with self._lock: # Safety check: ensure index is within bounds if self._server_idx >= len(self._servers): self._server_idx = 0 server = self._servers[self._server_idx] self._server_idx = (self._server_idx + 1) % len(self._servers) return server, {} async def finish_request(self, request: OpenAIRequest): pass class LoadBalancingRouter(Router): def __init__(self, server_role: ServerRole, servers: List[str] = None, metadata_server_cfg: MetadataServerConfig = None, metadata_server: JsonDictionary = None, use_tokens: bool = False, **kwargs): super().__init__(server_role, servers, metadata_server_cfg, metadata_server) # Load map between servers and their number of tokens processed self._server_state = {} self._server_load_heap = [] # Routing table to map requests to servers self._req_routing_table = {} self._use_tokens = use_tokens self._init_heap() def _on_servers_updated(self, old_servers, new_servers): """Rebuild the heap when the server list changes.""" # Keep the state for servers that still exist current_state = {} for server in new_servers: if server in self._server_state: # Keep existing state current_state[server] = self._server_state[server] else: # Initialize new server state current_state[server] = ServerState(server, self._use_tokens) # Update state and rebuild heap self._server_state = current_state self._server_load_heap = [] for server in new_servers: heapq.heappush(self._server_load_heap, (self._get_server_load(server), server)) def _init_heap(self): for server in self._servers: self._server_state[server] = ServerState(server, self._use_tokens) heapq.heappush(self._server_load_heap, (self._get_server_load(server), server)) async def get_next_server(self, request: OpenAIRequest) -> tuple[str, dict]: if not self._servers: if self._metadata_server: raise ValueError( f"No {self._server_role} servers available in metadata service" ) else: raise ValueError(f"No {self._server_role} servers available") async with self._lock: server = heapq.heappop(self._server_load_heap)[1] await self._server_state[server].increment_load(request) heapq.heappush(self._server_load_heap, (self._get_server_load(server), server)) self._req_routing_table[id(request)] = server return server, {} def _get_server_load(self, server): return self._server_state[server]._num_active_tokens if self._use_tokens \ else self._server_state[server]._num_active_requests async def finish_request(self, request: OpenAIRequest): async with self._lock: server = self._req_routing_table[id(request)] await self._server_state[server].decrement_load(request) heapq.heappush(self._server_load_heap, (self._get_server_load(server), server)) del self._req_routing_table[id(request)] def block_key_hasher(token_ids: list[int], parent_hash: Optional[int] = None) -> int: block_key = BlockKey(token_ids) return BlockKeyHasher.hash(block_key, 0 if parent_hash is None else parent_hash) class KvCacheAwareRouter(Router): def __init__(self, server_role: ServerRole = None, servers: list[str] = None, metadata_server_cfg: MetadataServerConfig = None, metadata_server: JsonDictionary = None, use_tokens: bool = False, max_batch_size: int = 64, tokens_per_block: int = 32, **kwargs): super().__init__(server_role, servers, metadata_server_cfg, metadata_server) self._lock = asyncio.Lock() self._use_tokens = use_tokens # Load map between servers and their number of tokens processed self._server_state: dict[str, KvCacheAwareServerState] = { server: KvCacheAwareServerState(server, use_tokens) for server in servers or [] } # Routing table to map requests to servers self._req_routing_table: dict[int, OpenAIRequest] = {} self._tokenizers = {} # TODO: use max_num_tokens? per server? self._max_batch_size = max_batch_size self._tokens_per_block = tokens_per_block def _tokenize(self, request: OpenAIRequest) -> list[list[int]]: prompts = request.prompt if isinstance(prompts, list) and isinstance(prompts[0], list): return prompts elif isinstance(prompts, list) and isinstance(prompts[0], int): return [prompts] elif isinstance(prompts, str): prompts = [prompts] else: assert isinstance(prompts, list) and isinstance(prompts[0], str) # TODO: send tokenize-only request instead of tokenizing locally if request.model not in self._tokenizers: self._tokenizers[request.model] = AutoTokenizer.from_pretrained( request.model) tokenizer = self._tokenizers[request.model] return [tokenizer(prompt)["input_ids"] for prompt in prompts] async def get_next_server(self, request: OpenAIRequest) -> tuple[str, dict]: async with self._lock: servers = list(self._server_state.keys()) token_lists = self._tokenize(request) block_hashes: list[list[int]] = [] for token_list in token_lists: hash_list = [] # in KvCacheManager, the last token is not included in the block key for t in range(0, len(token_list) - 1, self._tokens_per_block): t_end = min(t + self._tokens_per_block, len(token_list) - 1) hash_list.append( block_key_hasher(token_list[t:t_end], None if t == 0 else hash_list[-1])) block_hashes.append(hash_list) padded_tokens = sum( len(hash_list) for hash_list in block_hashes) * self._tokens_per_block # select the server by (KV match - load) # TODO: more options workloads = [ state.num_active_requests() for state in self._server_state.values() ] scores = [] matches = [] for i in range(len(servers)): server = servers[i] # https://github.com/ai-dynamo/dynamo/blob/main/docs/kv_cache_routing.md#kv-cache-routing-and-load-balancing matches.append( await self._server_state[server].matched_tokens(block_hashes)) score = matches[-1] / padded_tokens - workloads[ i] / self._max_batch_size scores.append(score) server = servers[scores.index(max(scores))] async with self._lock: await self._server_state[server].increment_load(request) self._req_routing_table[id(request)] = server return server, { "block_hashes": block_hashes, # list[list[int]] "token_lists": token_lists, # list[list[int]] "matches": matches, # list[int] } async def finish_request(self, request: OpenAIRequest, session: Optional[aiohttp.ClientSession] = None): async with self._lock: server = self._req_routing_table[id(request)] del self._req_routing_table[id(request)] if server in self._server_state: await self._server_state[server].decrement_load(request, session=session) def _on_servers_updated(self, old_servers, new_servers): raise NotImplementedError( "KvCacheAwareRouter does not support server updates") def create_router(router_config: Optional[RouterConfig], servers: Optional[List[str]], metadata_server_cfg: Optional[MetadataServerConfig] = None, metadata_server: Optional[JsonDictionary] = None) -> Router: """ Factory function to create different types of router instances. Args: router_type (str): Type of router to create. Supported values: - "round_robin": Creates a RoundRobinRouter (default) - "load_balancing": Creates a LoadBalancingRouter, which balances requests or tokens across instances - "kv_cache_aware": Creates a KvCacheAwareRouter, which balances requests across instances additionally based on KV cache hits servers: List of server URLs Returns: Router: An instance of the requested router type Raises: ValueError: If an unsupported router type is provided """ router_map = { "round_robin": RoundRobinRouter, "load_balancing": LoadBalancingRouter, "kv_cache_aware": KvCacheAwareRouter, } router_type = router_config.type if router_config else "round_robin" router_class = router_map.get(router_type.lower()) if router_class is None: raise ValueError(f"Unsupported router type: {router_type}. " f"Supported types are: {list(router_map.keys())}") extra_args = router_config.args if router_config else {} return router_class(router_config.server_role if router_config else None, servers, metadata_server_cfg, metadata_server, **extra_args)