mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-22 11:42:41 +08:00
Signed-off-by: Shunkang <182541032+Shunkangz@users.noreply.github.com> Signed-off-by: BatshevaBlack <132911331+BatshevaBlack@users.noreply.github.com> Co-authored-by: Shunkang <182541032+Shunkangz@users.noreply.github.com> Co-authored-by: Batsheva Black <bblack@login-eos01.eos.clusters.nvidia.com> Co-authored-by: BatshevaBlack <132911331+BatshevaBlack@users.noreply.github.com>
668 lines
26 KiB
Python
668 lines
26 KiB
Python
import asyncio
|
|
import heapq
|
|
from abc import ABC, abstractmethod
|
|
from typing import Dict, Iterable, List, Optional, Union
|
|
|
|
import aiohttp
|
|
from transformers import AutoTokenizer
|
|
|
|
from tensorrt_llm.bindings.internal.batch_manager import (BlockKey,
|
|
BlockKeyHasher)
|
|
from tensorrt_llm.llmapi.disagg_utils import RouterConfig, ServerRole
|
|
from tensorrt_llm.logger import logger
|
|
from tensorrt_llm.serve.metadata_server import JsonDictionary
|
|
from tensorrt_llm.serve.openai_protocol import (ChatCompletionRequest,
|
|
CompletionRequest)
|
|
|
|
OpenAIRequest = Union[CompletionRequest, ChatCompletionRequest]
|
|
|
|
|
|
def get_request_num_tokens(request: OpenAIRequest) -> int:
    """Return the number of prompt tokens carried by *request*.

    Only completion requests routed as plain or context-only are supported;
    chat and generation-only requests raise.

    Args:
        request: A CompletionRequest (ChatCompletionRequest is rejected).

    Returns:
        Total prompt length summed over all prompts in the request.

    Raises:
        ValueError: For ChatCompletionRequest, generation_only requests, or
            an unrecognized disaggregated request type.
    """
    if request.disaggregated_params is None or request.disaggregated_params.request_type == "context_only":
        if isinstance(request, ChatCompletionRequest):
            raise ValueError(
                "LoadBalancing router with tokens doesn't support ChatCompletionRequest yet"
            )

        # A single prompt is either a string or a list of token ids; a batch
        # is a list of such prompts. Normalize to a list of prompts. The
        # emptiness check guards against IndexError on an empty list prompt.
        if isinstance(request.prompt, str) or \
                (isinstance(request.prompt, list) and request.prompt
                 and isinstance(request.prompt[0], int)):
            prompts = [request.prompt]
        else:
            prompts = request.prompt

        # NOTE(review): for string prompts len() counts characters, not real
        # tokens — an approximation inherited from the original code.
        num_tokens = sum(len(prompt) for prompt in prompts)
    elif request.disaggregated_params.request_type == "generation_only":
        raise ValueError(
            "LoadBalancing router with tokens doesn't support generation_only requests"
        )
    else:
        raise ValueError(
            f"Unsupported request type: {request.disaggregated_params.request_type}"
        )

    return num_tokens
|
|
|
|
|
|
class ServerState:
    """Per-server load accounting (active requests and, optionally, tokens).

    All counter mutations are serialized through an asyncio lock so the
    router can update state from concurrent request handlers.
    """

    def __init__(self, server: str, use_tokens: bool = False):
        self._server = server
        self._num_active_requests = 0
        self._num_active_tokens = 0
        self._use_tokens = use_tokens
        self._lock = asyncio.Lock()
        # Lazily created aiohttp.ClientSession used by is_healthy().
        # BUGFIX: the original never assigned self._session, so is_healthy()
        # always raised AttributeError (swallowed) and returned False.
        self._session = None

    async def increment_load(self, request: OpenAIRequest):
        """Account for a newly routed request (and its tokens if enabled)."""
        num_tokens = get_request_num_tokens(request) if self._use_tokens else 0
        async with self._lock:
            self._num_active_requests += 1
            self._num_active_tokens += num_tokens

    async def decrement_load(self, request: OpenAIRequest):
        """Undo the accounting performed by increment_load()."""
        num_tokens = get_request_num_tokens(request) if self._use_tokens else 0
        async with self._lock:
            self._num_active_requests -= 1
            self._num_active_tokens -= num_tokens

    async def is_healthy(self) -> bool:
        """Return True iff the server's /health endpoint answers HTTP 200.

        Any network or session error is treated as unhealthy.
        """
        try:
            if self._session is None:
                self._session = aiohttp.ClientSession()
            async with self._session.get(self._server + "/health") as response:
                return response.status == 200
        except Exception:
            return False
|
|
|
|
|
|
class KvCacheAwareServerState(ServerState):
    """ServerState that also mirrors the server's resident KV-cache blocks.

    The block table is kept in sync from kv-cache events polled off the
    server and is used to estimate prefix-cache hits when routing.
    """

    def __init__(self,
                 server: str,
                 use_tokens: bool = False,
                 tokens_per_block: int = 32):
        super().__init__(server, use_tokens)
        # Hashes of KV blocks currently believed resident on the server
        self._kv_cache_block_table: set[int] = set()
        self._tokens_per_block = tokens_per_block

    def add_blocks(self, block_hashes: Iterable[int]):
        """Record blocks reported as stored on the server."""
        # set.update consumes any iterable in a single call
        self._kv_cache_block_table.update(block_hashes)

    def remove_blocks(self, block_hashes: Iterable[int]):
        """Forget blocks reported as removed; unknown hashes are ignored."""
        # difference_update has discard semantics (no KeyError)
        self._kv_cache_block_table.difference_update(block_hashes)

    def update_with_events(self, events: Iterable[dict]):
        """Apply a batch of kv-cache events to the block table.

        Each event is either the raw body or wrapped as
        {"id": <id>, "data": <event body>}.
        """
        for event_raw in events:
            event = event_raw.get("data", event_raw)

            if event["type"] == "stored":
                self.add_blocks(block["block_hash"]
                                for block in event["blocks"])
            elif event["type"] == "removed":
                self.remove_blocks(event["block_hashes"])

    async def poll_events(self, session: aiohttp.ClientSession):
        """Fetch the pending kv-cache events from the server."""
        async with session.post(self._server + "/kv_cache_events") as response:
            events_raw = await response.json()
        return events_raw

    async def matched_tokens(self, block_hashes: list[list[int]]) -> int:
        """Estimate how many prompt tokens hit this server's KV cache.

        For each prompt, counts whole-block matches from the start of the
        hash chain and stops at the first miss.
        """
        match_count = 0
        async with self._lock:
            for hash_list in block_hashes:
                for block_hash in hash_list:
                    # TODO: 1) parent hash verification, 2) partial matching
                    if block_hash in self._kv_cache_block_table:
                        match_count += self._tokens_per_block
                    else:
                        break
        return match_count

    async def decrement_load(self,
                             request: OpenAIRequest,
                             session: Optional[aiohttp.ClientSession] = None):
        """Release load and, if a session is given, refresh the block table."""
        num_tokens = get_request_num_tokens(request) if self._use_tokens else 0
        if session is not None:
            events_raw = await self.poll_events(session)
        else:
            events_raw = None
        async with self._lock:
            self._num_active_requests -= 1
            self._num_active_tokens -= num_tokens
            if events_raw is not None:
                self.update_with_events(events_raw)

    def num_active_tokens(self):
        """Current token-load counter (unsynchronized read)."""
        return self._num_active_tokens

    def num_active_requests(self):
        """Current request-load counter (unsynchronized read)."""
        return self._num_active_requests
|
|
|
|
|
|
class Router(ABC):
    """Base class for request routers with optional dynamic server discovery.

    Subclasses implement get_next_server()/finish_request(). When a metadata
    server is configured, a background task periodically fetches the server
    list, health-checks each entry, filters by role, and swaps in the result.
    """

    def __init__(self,
                 server_role: ServerRole,
                 servers: Optional[List[str]] = None,
                 metadata_server: JsonDictionary = None):
        self._servers = servers or []
        self._metadata_server = metadata_server
        self._server_role = server_role
        self._lock = asyncio.Lock()
        self._monitor_task = None
        # Shared aiohttp session for health checks; created on demand
        self._session = None
        self._health_check_timeout = 5.0  # Default timeout in seconds

    def _on_servers_updated(self, old_servers, new_servers):
        """Called when the server list changes. Override in subclasses to handle index resets.

        Args:
            old_servers: The previous server list
            new_servers: The new server list
        """

    @abstractmethod
    async def get_next_server(self, request: OpenAIRequest) -> tuple[str, dict]:
        '''Select server by request and return some intermediate information'''

    @abstractmethod
    async def finish_request(self, request: OpenAIRequest):
        pass

    async def start_server_monitoring(self, poll_interval: int = 10):
        """Start monitoring servers update from metadata service"""
        if not self._metadata_server:
            logger.info(
                "No metadata server configured, skipping server monitoring")
            return

        # Create a session for health checks if it doesn't exist
        if not self._session:
            self._session = aiohttp.ClientSession()

        logger.info(
            f"Starting server monitoring for {self._server_role} servers")
        self._monitor_task = asyncio.create_task(
            self._monitor_servers(poll_interval))

    async def stop_server_monitoring(self):
        """Stop monitoring servers update from metadata service"""
        if self._monitor_task:
            logger.info(
                f"Stopping server monitoring for {self._server_role} servers")
            self._monitor_task.cancel()
            try:
                await self._monitor_task
            except asyncio.CancelledError:
                pass  # expected result of cancel()
            self._monitor_task = None

        # Close session when stopping monitoring
        await self.close_session()

    async def close_session(self):
        """Close the shared HTTP session; errors are logged, not raised."""
        if self._session:
            try:
                await self._session.close()
                self._session = None
                logger.debug("HTTP session closed")
            except Exception as e:
                logger.error(f"Error closing session: {e}")
                self._session = None

    async def _monitor_servers(self, poll_interval: int = 10):
        """Poll the metadata service forever, refreshing the live server list."""
        while True:
            try:
                if self._metadata_server:
                    # Get servers from metadata
                    server_key_map = await self.fetch_live_servers()

                    # Check health and get live servers
                    live_servers = await self.check_servers_health(
                        server_key_map)

                    # Filter by server role if needed
                    role_specific_servers = self._filter_servers_by_role(
                        live_servers, server_key_map)

                    # Use filtered servers if available
                    final_servers = role_specific_servers if role_specific_servers else []

                    # Update server list
                    async with self._lock:
                        if final_servers != self._servers:
                            num_old_servers = len(self._servers)
                            old_servers = self._servers.copy()
                            self._servers = final_servers
                            num_new_servers = len(self._servers)

                            # Call handler for server list changes
                            self._on_servers_updated(old_servers, self._servers)

                            logger.info(
                                f"Updated {self._server_role} server list: {num_old_servers} -> {num_new_servers} servers"
                            )
                            # NOTE(review): compares the project logger's level
                            # to a string — assumes logger.level is a str;
                            # confirm against tensorrt_llm.logger.
                            if logger.level == "debug" and self._servers:
                                for server in self._servers:
                                    logger.debug(f"  - {server}")
                        else:
                            logger.debug(
                                f"No change in {self._server_role} server list: {len(self._servers)} servers"
                            )
            except Exception as e:
                logger.error(f"Error in server monitoring: {e}")

            # Wait before next poll
            await asyncio.sleep(poll_interval)

    def _filter_servers_by_role(self, servers, server_key_map):
        """Filter servers by role (context or generation)"""
        if not self._metadata_server or not servers:
            return []

        filtered_servers = []
        # Invert to get {url: key} for lookup
        url_to_key = {url: key for key, url in server_key_map.items()}

        for server_url in servers:
            key = url_to_key.get(server_url)
            if key:
                server_metadata = self._metadata_server.get(key)
                if server_metadata:
                    # Use either server_type or server_role field
                    server_type = server_metadata.get('server_type', '').lower()
                    if not server_type:
                        server_type = server_metadata.get('server_role',
                                                          '').lower()

                    # (Removed dead code: the original split the URL on ':'
                    # and evaluated parts[2] as a bare no-op expression.)

                    # Check if server type matches our role
                    if (self._server_role == ServerRole.CONTEXT and server_type == 'context') or \
                       (self._server_role == ServerRole.GENERATION and server_type == 'generation'):
                        filtered_servers.append(server_url)

        return filtered_servers

    async def fetch_live_servers(self) -> Dict[str, str]:
        """Fetch all servers from metadata service and return {key: url} mapping"""
        if not self._metadata_server:
            # Only use static servers if no metadata server
            return {server: "" for server in self._servers}

        # If metadata server is available, ignore static server list entirely
        server_key_map = {}
        try:
            # Get all keys from the metadata server
            all_keys = self._metadata_server.keys()
            logger.debug(f"Found {len(all_keys)} keys in metadata server")

            # Filter keys that start with 'trtllm/' and extract server metadata
            matching_keys = 0
            for key in all_keys:
                if key.startswith('trtllm/'):
                    matching_keys += 1
                    server_metadata = self._metadata_server.get(key)
                    if server_metadata and isinstance(
                            server_metadata, dict) and 'url' in server_metadata:
                        server_key_map[key] = server_metadata['url']

                        # Check if metadata includes health check timeout
                        if 'health_check_timeout' in server_metadata:
                            try:
                                self._health_check_timeout = float(
                                    server_metadata['health_check_timeout'])
                                logger.debug(
                                    f"Using health check timeout: {self._health_check_timeout}s"
                                )
                            except (ValueError, TypeError):
                                logger.warning(
                                    f"Invalid health_check_timeout value: {server_metadata['health_check_timeout']}"
                                )

            if server_key_map:
                logger.info(
                    f"Using {len(server_key_map)} servers from metadata service"
                )
            else:
                logger.warning("No servers found in metadata service")

        except Exception as e:
            logger.error(f"Error fetching servers from metadata service: {e}")

        return server_key_map

    async def check_servers_health(self,
                                   server_key_map: Dict[str, str]) -> List[str]:
        """Check health of servers and remove dead ones from metadata service"""
        live_servers = []
        dead_servers = []

        try:
            # Check health of each server
            for key, server_url in server_key_map.items():
                # First attempt - no printing errors
                is_healthy = await self._check_server_health(server_url,
                                                             silent=True)

                # If first attempt failed, try again before declaring server dead
                if not is_healthy:
                    # Second attempt - will print errors if it fails
                    is_healthy = await self._check_server_health(server_url,
                                                                 silent=False)

                    if not is_healthy:
                        # Only now add to dead servers
                        dead_servers.append((key, server_url))
                        logger.warning(
                            f"Server {server_url} is not healthy after retry - removing"
                        )
                    else:
                        live_servers.append(server_url)
                else:
                    live_servers.append(server_url)

            # Remove dead servers from etcd
            for key, dead_server in dead_servers:
                try:
                    logger.info(
                        f"Removing dead server {dead_server} from metadata server"
                    )
                    self._metadata_server.remove(key)
                except Exception as e:
                    logger.error(
                        f"Error removing dead server from metadata service: {e}"
                    )

        except Exception as e:
            logger.error(f"Error checking server health: {e}")

        # Fall back to the previous list rather than an empty one
        return live_servers if live_servers else self._servers

    async def _check_server_health(self, server_url, silent=False) -> bool:
        """Check if a server is healthy by querying its health endpoint"""
        if not self._session:
            self._session = aiohttp.ClientSession()

        try:
            async with self._session.get(
                    f"{server_url}/health",
                    timeout=self._health_check_timeout) as response:
                if response.status != 200:
                    if not silent:
                        logger.warning(
                            f"Server {server_url} is not healthy (status: {response.status})"
                        )
                    return False
                return True
        except Exception as e:
            if not silent:
                logger.warning(f"Server {server_url} is not reachable: {e}")
            return False
|
|
|
|
|
|
class RoundRobinRouter(Router):
    """Router that hands out the configured servers in a fixed rotation."""

    def __init__(self,
                 server_role: ServerRole,
                 servers: List[str] = None,
                 metadata_server: JsonDictionary = None,
                 **kwargs):
        super().__init__(server_role, servers, metadata_server)
        # Position of the next server to hand out
        self._server_idx = 0

    def _on_servers_updated(self, old_servers, new_servers):
        """Keep the rotation index valid after the server list changes."""
        list_shrank = len(new_servers) < len(old_servers)
        # Reset whenever servers were removed, or as a safety net whenever
        # the index would fall outside the new list.
        if list_shrank or self._server_idx >= len(new_servers):
            self._server_idx = 0

    async def get_next_server(self, request: OpenAIRequest) -> tuple[str, dict]:
        if not self._servers:
            if self._metadata_server:
                raise ValueError(
                    f"No {self._server_role} servers available in metadata service"
                )
            raise ValueError(f"No {self._server_role} servers available")

        async with self._lock:
            # Clamp the index defensively before using it
            idx = self._server_idx
            if idx >= len(self._servers):
                idx = 0
            chosen = self._servers[idx]
            self._server_idx = (idx + 1) % len(self._servers)
            return chosen, {}

    async def finish_request(self, request: OpenAIRequest):
        pass
|
|
|
|
|
|
class LoadBalancingRouter(Router):
    """Router that sends each request to the currently least-loaded server.

    Load is measured in active requests (default) or active prompt tokens
    (use_tokens=True). A min-heap keyed by load picks the next server.
    """

    def __init__(self,
                 server_role: ServerRole,
                 servers: List[str] = None,
                 metadata_server: JsonDictionary = None,
                 use_tokens: bool = False,
                 **kwargs):
        super().__init__(server_role, servers, metadata_server)
        # Load map between servers and their number of tokens processed
        self._server_state = {}
        self._server_load_heap = []

        # Routing table to map requests to servers
        self._req_routing_table = {}

        self._use_tokens = use_tokens
        self._init_heap()

    def _rebuild_heap(self, servers):
        """Rebuild the load heap for *servers* in one O(n) heapify pass."""
        self._server_load_heap = [(self._get_server_load(server), server)
                                  for server in servers]
        heapq.heapify(self._server_load_heap)

    def _on_servers_updated(self, old_servers, new_servers):
        """Rebuild the heap when the server list changes."""
        # Keep the state for servers that still exist
        current_state = {}
        for server in new_servers:
            if server in self._server_state:
                # Keep existing state
                current_state[server] = self._server_state[server]
            else:
                # Initialize new server state
                current_state[server] = ServerState(server, self._use_tokens)

        # Update state and rebuild heap (heapify is O(n) vs repeated
        # heappush at O(n log n))
        self._server_state = current_state
        self._rebuild_heap(new_servers)

    def _init_heap(self):
        """Create initial per-server state and seed the load heap."""
        for server in self._servers:
            self._server_state[server] = ServerState(server, self._use_tokens)
        self._rebuild_heap(self._servers)

    async def get_next_server(self, request: OpenAIRequest) -> tuple[str, dict]:
        """Pop the least-loaded server, charge the request to it, re-push."""
        if not self._servers:
            if self._metadata_server:
                raise ValueError(
                    f"No {self._server_role} servers available in metadata service"
                )
            else:
                raise ValueError(f"No {self._server_role} servers available")

        async with self._lock:
            server = heapq.heappop(self._server_load_heap)[1]
            await self._server_state[server].increment_load(request)
            heapq.heappush(self._server_load_heap,
                           (self._get_server_load(server), server))

            # Remember the assignment so finish_request can release it
            self._req_routing_table[id(request)] = server

            return server, {}

    def _get_server_load(self, server):
        """Current load metric for *server* (tokens or request count)."""
        return self._server_state[server]._num_active_tokens if self._use_tokens \
            else self._server_state[server]._num_active_requests

    async def finish_request(self, request: OpenAIRequest):
        """Release the load charged to the server that handled *request*."""
        async with self._lock:
            server = self._req_routing_table[id(request)]
            await self._server_state[server].decrement_load(request)
            # Push the refreshed load; stale heap entries are tolerated
            heapq.heappush(self._server_load_heap,
                           (self._get_server_load(server), server))
            del self._req_routing_table[id(request)]
|
|
|
|
|
|
def block_key_hasher(token_ids: list[int],
                     parent_hash: Optional[int] = None) -> int:
    """Hash one KV-cache block of token ids, chained to its parent's hash.

    A missing parent (first block of a sequence) hashes against 0.
    """
    parent = 0 if parent_hash is None else parent_hash
    return BlockKeyHasher.hash(BlockKey(token_ids), parent)
|
|
|
|
|
|
class KvCacheAwareRouter(Router):
    """Router that scores servers by estimated KV-cache reuse minus load."""

    def __init__(self,
                 server_role: ServerRole = None,
                 servers: list[str] = None,
                 metadata_server: JsonDictionary = None,
                 use_tokens: bool = False,
                 max_batch_size: int = 64,
                 tokens_per_block: int = 32,
                 **kwargs):
        super().__init__(server_role, servers, metadata_server)

        # Load map between servers and their number of tokens processed.
        # BUGFIX: tokens_per_block is now forwarded so per-server match
        # accounting uses the same block size as the hashing loop below
        # (previously the states silently fell back to their default of 32).
        self._server_state: dict[str, KvCacheAwareServerState] = {
            server: KvCacheAwareServerState(server, use_tokens,
                                            tokens_per_block)
            for server in servers or []
        }

        # Routing table to map requests to servers
        self._req_routing_table: dict[int, OpenAIRequest] = {}

        # Per-model tokenizer cache, populated lazily in _tokenize()
        self._tokenizers = {}
        # TODO: use max_num_tokens? per server?
        self._max_batch_size = max_batch_size
        self._tokens_per_block = tokens_per_block

    def _tokenize(self, request: OpenAIRequest) -> list[list[int]]:
        """Normalize request.prompt to a batch of token-id lists."""
        prompts = request.prompt
        if isinstance(prompts, list) and isinstance(prompts[0], list):
            return prompts
        elif isinstance(prompts, list) and isinstance(prompts[0], int):
            return [prompts]
        elif isinstance(prompts, str):
            prompts = [prompts]
        else:
            assert isinstance(prompts, list) and isinstance(prompts[0], str)

        # TODO: send tokenize-only request instead of tokenizing locally
        if request.model not in self._tokenizers:
            self._tokenizers[request.model] = AutoTokenizer.from_pretrained(
                request.model)
        tokenizer = self._tokenizers[request.model]
        return [tokenizer(prompt)["input_ids"] for prompt in prompts]

    async def get_next_server(self, request: OpenAIRequest) -> tuple[str, dict]:
        """Pick the server maximizing (KV match ratio - normalized load)."""
        servers = list(self._server_state.keys())
        # Guard added for consistency with the other routers (max() below
        # would otherwise raise an opaque ValueError on an empty sequence).
        if not servers:
            raise ValueError(f"No {self._server_role} servers available")

        token_lists = self._tokenize(request)
        block_hashes: list[list[int]] = []
        for token_list in token_lists:
            hash_list = []
            # in KvCacheManager, the last token is not included in the block key
            for t in range(0, len(token_list) - 1, self._tokens_per_block):
                t_end = min(t + self._tokens_per_block, len(token_list) - 1)
                hash_list.append(
                    block_key_hasher(token_list[t:t_end],
                                     None if t == 0 else hash_list[-1]))
            block_hashes.append(hash_list)
        # BUGFIX: clamp to >= 1 — a single-token prompt yields zero blocks,
        # which previously caused a ZeroDivisionError in the score below.
        padded_tokens = max(
            1,
            sum(len(hash_list)
                for hash_list in block_hashes) * self._tokens_per_block)
        # select the server by (KV match - load)
        # TODO: more options
        workloads = [
            state.num_active_requests()
            for state in self._server_state.values()
        ]
        scores = []
        matches = []
        for i in range(len(servers)):
            server = servers[i]
            # https://github.com/ai-dynamo/dynamo/blob/main/docs/kv_cache_routing.md#kv-cache-routing-and-load-balancing
            matches.append(
                await self._server_state[server].matched_tokens(block_hashes))
            score = matches[-1] / padded_tokens - workloads[
                i] / self._max_batch_size
            scores.append(score)
        server = servers[scores.index(max(scores))]
        await self._server_state[server].increment_load(request)
        async with self._lock:
            self._req_routing_table[id(request)] = server
        return server, {
            "block_hashes": block_hashes,  # list[list[int]]
            "token_lists": token_lists,  # list[list[int]]
            "matches": matches,  # list[int]
        }

    async def finish_request(self,
                             request: OpenAIRequest,
                             session: Optional[aiohttp.ClientSession] = None):
        """Release the request's load; optionally refresh KV events via *session*."""
        async with self._lock:
            server = self._req_routing_table[id(request)]
            del self._req_routing_table[id(request)]
        await self._server_state[server].decrement_load(request,
                                                        session=session)
|
|
|
|
|
|
def create_router(router_config: Optional[RouterConfig],
                  servers: List[str],
                  metadata_server: JsonDictionary = None) -> Router:
    """
    Factory function to create different types of router instances.

    Args:
        router_config: Configuration carrying the router type, server role,
            and extra per-router keyword arguments. Supported types:
            - "round_robin": RoundRobinRouter (also the default when
              router_config is None)
            - "load_balancing": LoadBalancingRouter, which balances requests
              or tokens across instances
            - "kv_cache_aware": KvCacheAwareRouter, which balances requests
              across instances additionally based on KV cache hits
        servers: List of server URLs
        metadata_server: Optional metadata service used for dynamic server
            discovery

    Returns:
        Router: An instance of the requested router type

    Raises:
        ValueError: If an unsupported router type is provided
    """
    if router_config is None:
        # Create a default router without server_role
        return RoundRobinRouter(None, servers)

    router_map = {
        "round_robin": RoundRobinRouter,
        "load_balancing": LoadBalancingRouter,
        "kv_cache_aware": KvCacheAwareRouter,
    }

    router_type = router_config.type
    router_class = router_map.get(router_type.lower())
    if router_class is None:
        raise ValueError(f"Unsupported router type: {router_type}. "
                         f"Supported types are: {list(router_map.keys())}")

    # Pass server_role as the first argument
    return router_class(router_config.server_role, servers, metadata_server,
                        **router_config.args)
|