feat: Add integration of etcd (#3738)

Signed-off-by: Shunkang <182541032+Shunkangz@users.noreply.github.com>
Signed-off-by: BatshevaBlack <132911331+BatshevaBlack@users.noreply.github.com>
Co-authored-by: Shunkang <182541032+Shunkangz@users.noreply.github.com>
Co-authored-by: Batsheva Black <bblack@login-eos01.eos.clusters.nvidia.com>
Co-authored-by: BatshevaBlack <132911331+BatshevaBlack@users.noreply.github.com>
Shunkangz 2025-06-03 20:01:44 +08:00 committed by GitHub
parent 3fe4a1842a
commit ae9a6cf24f
13 changed files with 1407 additions and 33 deletions

View File

@ -55,3 +55,4 @@ nvtx
matplotlib # FIXME: this is added to make nvtx happy
meson
ninja
etcd3

View File

@ -18,7 +18,9 @@ from tensorrt_llm.llmapi import (LLM, BuildConfig, CapacitySchedulerPolicy,
DynamicBatchConfig, KvCacheConfig,
SchedulerConfig)
from tensorrt_llm.llmapi.disagg_utils import (CtxGenServerConfig,
parse_disagg_config_file)
MetadataServerConfig, ServerRole,
parse_disagg_config_file,
parse_metadata_server_config_file)
from tensorrt_llm.llmapi.llm_utils import update_llm_args_with_extra_dict
from tensorrt_llm.llmapi.mpi_session import find_free_port
from tensorrt_llm.llmapi.reasoning_parser import ReasoningParserFactory
@ -127,7 +129,12 @@ def get_llm_args(model: str,
return llm_args, llm_args_extra_dict
def launch_server(host: str, port: int, llm_args: dict):
def launch_server(host: str,
port: int,
llm_args: dict,
metadata_server_cfg: Optional[MetadataServerConfig] = None,
server_role: Optional[ServerRole] = None):
backend = llm_args["backend"]
model = llm_args["model"]
@ -136,7 +143,10 @@ def launch_server(host: str, port: int, llm_args: dict):
else:
llm = LLM(**llm_args)
server = OpenAIServer(llm=llm, model=model)
server = OpenAIServer(llm=llm,
model=model,
server_role=server_role,
metadata_server_cfg=metadata_server_cfg)
asyncio.run(server(host, port))
@ -228,6 +238,16 @@ def launch_server(host: str, port: int, llm_args: dict):
default=None,
help="[Experimental] Specify the parser for reasoning models.",
)
@click.option("--metadata_server_config_file",
type=str,
default=None,
help="Path to metadata server config file")
@click.option(
"--server_role",
type=str,
default=None,
help="Server role. Specify this value only if running in disaggregated mode."
)
def serve(model: str, tokenizer: Optional[str], host: str, port: int,
log_level: str, backend: str, max_beam_width: int,
max_batch_size: int, max_num_tokens: int, max_seq_len: int,
@ -235,8 +255,9 @@ def serve(model: str, tokenizer: Optional[str], host: str, port: int,
cluster_size: Optional[int], gpus_per_node: Optional[int],
kv_cache_free_gpu_memory_fraction: float,
num_postprocess_workers: int, trust_remote_code: bool,
extra_llm_api_options: Optional[str],
reasoning_parser: Optional[str]):
extra_llm_api_options: Optional[str], reasoning_parser: Optional[str],
metadata_server_config_file: Optional[str],
server_role: Optional[str]):
"""Running an OpenAI API compatible server
MODEL: model name | HF checkpoint path | TensorRT engine path
@ -267,7 +288,16 @@ def serve(model: str, tokenizer: Optional[str], host: str, port: int,
llm_args_extra_dict = yaml.safe_load(f)
llm_args = update_llm_args_with_extra_dict(llm_args, llm_args_extra_dict)
launch_server(host, port, llm_args)
metadata_server_cfg = parse_metadata_server_config_file(
metadata_server_config_file)
if metadata_server_cfg is not None:
try:
server_role = ServerRole[server_role.upper()]
except (KeyError, AttributeError):  # Enum lookup raises KeyError; None has no .upper()
raise ValueError(f"Invalid server role: {server_role}. " \
f"Must be one of: {', '.join([role.name for role in ServerRole])}")
launch_server(host, port, llm_args, metadata_server_cfg, server_role)
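Note that the role string is resolved by enum member name, so only the exact ServerRole names are accepted. A brief illustration (not part of the diff):

from tensorrt_llm.llmapi.disagg_utils import ServerRole

ServerRole["context".upper()]   # -> ServerRole.CONTEXT
ServerRole["GENERATION"]        # -> ServerRole.GENERATION
ServerRole["ctx".upper()]       # raises KeyError, reported as an invalid role above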
def get_ctx_gen_server_urls(
@ -289,6 +319,11 @@ def get_ctx_gen_server_urls(
type=str,
default=None,
help="Specific option for disaggregated mode.")
@click.option("-m",
"--metadata_server_config_file",
type=str,
default=None,
help="Path to metadata server config file")
@click.option("-t",
"--server_start_timeout",
type=int,
@ -299,15 +334,27 @@ def get_ctx_gen_server_urls(
type=int,
default=180,
help="Request timeout")
def disaggregated(config_file: Optional[str], server_start_timeout: int,
request_timeout: int):
@click.option("-l",
'--log_level',
type=click.Choice(severity_map.keys()),
default='info',
help="The logging level.")
def disaggregated(config_file: Optional[str],
metadata_server_config_file: Optional[str],
server_start_timeout: int, request_timeout: int,
log_level: str):
"""Running server in disaggregated mode"""
logger.set_level(log_level)
disagg_cfg = parse_disagg_config_file(config_file)
ctx_server_urls, gen_server_urls = get_ctx_gen_server_urls(
disagg_cfg.server_configs)
metadata_server_cfg = parse_metadata_server_config_file(
metadata_server_config_file)
server = OpenAIDisaggServer(
ctx_servers=ctx_server_urls,
gen_servers=gen_server_urls,
@ -315,7 +362,8 @@ def disaggregated(config_file: Optional[str], server_start_timeout: int,
server_start_timeout_secs=server_start_timeout,
ctx_router_config=disagg_cfg.ctx_router_config,
gen_router_config=disagg_cfg.gen_router_config,
conditional_disagg_config=disagg_cfg.conditional_disagg_config)
conditional_disagg_config=disagg_cfg.conditional_disagg_config,
metadata_server_cfg=metadata_server_cfg)
asyncio.run(server(disagg_cfg.hostname, disagg_cfg.port))
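Taken together, a minimal disaggregated launch with the new flags might look like the sketch below; the model, ports, and config file names are illustrative and mirror the integration test later in this commit:

trtllm-serve TinyLlama/TinyLlama-1.1B-Chat-v1.0 --host localhost --port 8001 --backend pytorch \
    --metadata_server_config_file etcd_config.yaml --server_role CONTEXT
trtllm-serve TinyLlama/TinyLlama-1.1B-Chat-v1.0 --host localhost --port 8002 --backend pytorch \
    --metadata_server_config_file etcd_config.yaml --server_role GENERATION
trtllm-serve disaggregated -c disagg_config.yaml -m etcd_config.yaml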

View File

@ -1,5 +1,6 @@
import logging
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, List, Literal, Optional, Tuple
import yaml
@ -15,6 +16,11 @@ __all__ = [
]
class ServerRole(Enum):
CONTEXT = 0
GENERATION = 1
@dataclass
class CtxGenServerConfig():
type: Literal['ctx', 'gen']
@ -28,6 +34,7 @@ class CtxGenServerConfig():
class RouterConfig():
type: str = "round_robin"
args: dict = field(default_factory=dict)
server_role: Optional[ServerRole] = None
@dataclass
@ -45,6 +52,14 @@ class DisaggServerConfig():
conditional_disagg_config: Optional[ConditionalDisaggConfig] = None
@dataclass
class MetadataServerConfig():
server_type: Literal['etcd']
hostname: str = "localhost"
port: int = 2379
health_check_timeout: float = 5.0
def parse_disagg_config_file(yaml_config_file: str):
with open(yaml_config_file, 'r') as file:
@ -85,6 +100,9 @@ def extract_disagg_cfg(hostname: str = 'localhost',
ctx_router_config = extract_router_config(context_servers)
gen_router_config = extract_router_config(generation_servers)
ctx_router_config.server_role = ServerRole.CONTEXT
gen_router_config.server_role = ServerRole.GENERATION
conditional_disagg_config = ConditionalDisaggConfig(
**conditional_disagg_config) if conditional_disagg_config else None
@ -220,3 +238,14 @@ def split_world_comm(
)
return is_leader, instance_idx, sub_comm
def parse_metadata_server_config_file(
metadata_server_config_file: Optional[str]
) -> Optional[MetadataServerConfig]:
if metadata_server_config_file is None:
return None
with open(metadata_server_config_file, 'r') as file:
config = yaml.safe_load(file)
return MetadataServerConfig(**config)
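The YAML consumed by parse_metadata_server_config_file maps one-to-one onto MetadataServerConfig fields; the integration test later in this commit writes exactly this shape:

server_type: "etcd"
hostname: "localhost"
port: 2379
health_check_timeout: 5.0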

View File

@ -2,7 +2,9 @@ import atexit
import json
import os
import shutil
import socket
import tempfile
import time
import weakref
from pathlib import Path
from typing import Any, List, Literal, Optional, Sequence, Union
@ -86,6 +88,7 @@ LLM_DOCSTRING = LLMARGS_EXPLICIT_DOCSTRING + """
Attributes:
tokenizer (tensorrt_llm.llmapi.tokenizer.TokenizerBase, optional): The tokenizer loaded by LLM instance, if any.
workspace (pathlib.Path): The directory to store intermediate files.
llm_id (str): The unique ID of the LLM instance.
"""
@ -110,6 +113,7 @@ class LLM:
**kwargs: Any) -> None:
self._executor_cls = kwargs.pop("executor_cls", GenerationExecutor)
self._llm_id = None
try:
llm_args_cls = TorchLlmArgs if kwargs.get(
@ -186,6 +190,16 @@ class LLM:
def workspace(self) -> Path:
return Path(self._workspace.name) if self._on_trt_backend else None
@property
def llm_id(self) -> str:
if self._llm_id is None:
hostname = socket.gethostname()
pid = os.getpid()
timestamp = int(time.time() * 1000)
self._llm_id = f"{hostname}-{pid}-{timestamp}"
return self._llm_id
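For orientation: the ID is hostname-PID-millisecond-timestamp, and it becomes the etcd registration key used by the OpenAI server below. A small sketch with illustrative values, not part of the diff:

import os
import socket
import time

# same recipe as LLM.llm_id above
llm_id = f"{socket.gethostname()}-{os.getpid()}-{int(time.time() * 1000)}"
etcd_key = f"trtllm/{llm_id}"  # e.g. "trtllm/node01-12345-1717412504123"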
def generate(
self,
inputs: Union[PromptInputs, Sequence[PromptInputs]],

View File

@ -0,0 +1,95 @@
import json
from abc import ABC, abstractmethod
from typing import Optional
from tensorrt_llm.llmapi.disagg_utils import MetadataServerConfig
from tensorrt_llm.logger import logger
try:
import etcd3
except Exception as e:
logger.warning(f"etcd3 is not installed correctly: {e}")
class RemoteDictionary(ABC):
@abstractmethod
def get(self, key: str) -> str:
pass
@abstractmethod
def put(self, key: str, value: str):
pass
@abstractmethod
def remove(self, key: str):
pass
@abstractmethod
def keys(self) -> list[str]:
pass
class EtcdDictionary(RemoteDictionary):
def __init__(self, host: str, port: int):
self._client = etcd3.client(host, port)
def get(self, key: str) -> str:
return self._client.get(key)
def put(self, key: str, value: str):
self._client.put(key, value)
def remove(self, key: str):
self._client.delete(key)
def keys(self) -> list[str]:
# TODO: Confirm the final save key format
# This implementation assumes that key is in the
# format of "trtllm/executor_name/key"
unique_keys = set()
for _, metadata in self._client.get_all():
key = metadata.key.decode('utf-8')
sub_keys = key.split('/')
if len(sub_keys) >= 2:
top_prefix = "/".join(sub_keys[:2])
unique_keys.add(top_prefix)
return list(unique_keys)
class JsonDictionary:
def __init__(self, remote_dict: RemoteDictionary):
    self._dict = remote_dict
def get(self, key: str) -> str:
bytes_data, _ = self._dict.get(key)
return json.loads(bytes_data.decode('utf-8'))
def put(self, key: str, value: str):
self._dict.put(key, json.dumps(value))
def remove(self, key: str):
self._dict.remove(key)
def keys(self) -> list[str]:
return self._dict.keys()
def create_metadata_server(
metadata_server_cfg: Optional[MetadataServerConfig]
) -> Optional[JsonDictionary]:
if metadata_server_cfg is None:
return None
if metadata_server_cfg.server_type == 'etcd':
remote_dict = EtcdDictionary(host=metadata_server_cfg.hostname,
                             port=metadata_server_cfg.port)
else:
raise ValueError(
f"Unsupported metadata server type: {metadata_server_cfg.server_type}"
)
return JsonDictionary(remote_dict)
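A minimal round trip through this module, assuming an etcd instance is reachable on the default port:

from tensorrt_llm.llmapi.disagg_utils import MetadataServerConfig
from tensorrt_llm.serve.metadata_server import create_metadata_server

cfg = MetadataServerConfig(server_type='etcd', hostname='localhost', port=2379)
store = create_metadata_server(cfg)  # JsonDictionary wrapping an EtcdDictionary
store.put("trtllm/example-id", {"url": "http://localhost:8001"})
print(store.keys())                  # ['trtllm/example-id'], per the key format above
store.remove("trtllm/example-id")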

View File

@ -18,7 +18,9 @@ from fastapi.responses import JSONResponse, Response, StreamingResponse
# yapf: disable
from tensorrt_llm.executor import CppExecutorError
from tensorrt_llm.llmapi.disagg_utils import (ConditionalDisaggConfig,
MetadataServerConfig,
RouterConfig)
from tensorrt_llm.serve.metadata_server import create_metadata_server
from tensorrt_llm.serve.openai_protocol import (ChatCompletionRequest,
ChatCompletionResponse,
CompletionRequest,
@ -42,14 +44,17 @@ class OpenAIDisaggServer:
server_start_timeout_secs: int = 180,
ctx_router_config: Optional[RouterConfig] = None,
gen_router_config: Optional[RouterConfig] = None,
conditional_disagg_config: Optional[ConditionalDisaggConfig] = None):
conditional_disagg_config: Optional[ConditionalDisaggConfig] = None,
metadata_server_cfg: Optional[MetadataServerConfig] = None):
self.ctx_servers = ctx_servers
self.gen_servers = gen_servers
self.ctx_router = create_router(ctx_router_config, ctx_servers)
self.gen_router = create_router(gen_router_config, gen_servers)
self.metadata_server = create_metadata_server(metadata_server_cfg)
self.ctx_router = create_router(ctx_router_config, ctx_servers, self.metadata_server)
self.gen_router = create_router(gen_router_config, gen_servers, self.metadata_server)
self.conditional_disagg_config = conditional_disagg_config
if (len(self.gen_servers) == 0):
raise ValueError("At least one generation server must be provided")
@ -71,7 +76,19 @@ class OpenAIDisaggServer:
logging.info("Waiting for context and generation servers to be ready")
await self.wait_for_servers_ready(server_start_timeout_secs)
if self.metadata_server:
logging.info("Starting server monitoring via metadata service")
await self.ctx_router.start_server_monitoring()
await self.gen_router.start_server_monitoring()
yield
if self.metadata_server:
logging.info("Stopping server monitoring via metadata service")
await self.ctx_router.stop_server_monitoring()
await self.gen_router.stop_server_monitoring()
await self.session.close() # Ensure session cleanup
self.app = FastAPI(lifespan=lifespan)
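For reference, wiring the proxy to the metadata service only adds the new keyword; a sketch assuming the remaining constructor arguments keep their defaults (module path and URL scheme inferred from the file layout):

from tensorrt_llm.llmapi.disagg_utils import MetadataServerConfig
from tensorrt_llm.serve.openai_disagg_server import OpenAIDisaggServer

server = OpenAIDisaggServer(
    ctx_servers=["http://localhost:8001"],
    gen_servers=["http://localhost:8002"],
    metadata_server_cfg=MetadataServerConfig(server_type='etcd'))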

View File

@ -3,6 +3,7 @@ import asyncio
import signal
import traceback
from contextlib import asynccontextmanager
from datetime import datetime
from http import HTTPStatus
from pathlib import Path
from typing import AsyncGenerator, AsyncIterator, List, Optional, Tuple
@ -18,12 +19,14 @@ from tensorrt_llm.executor import CppExecutorError
from tensorrt_llm.executor.postproc_worker import PostprocParams
from tensorrt_llm.inputs import prompt_inputs
from tensorrt_llm.llmapi import LLM
from tensorrt_llm.llmapi.disagg_utils import MetadataServerConfig, ServerRole
from tensorrt_llm.llmapi.llm import RequestOutput
from tensorrt_llm.logger import logger
from tensorrt_llm.serve.chat_utils import (ConversationMessage,
apply_chat_template,
check_multiple_response,
parse_chat_messages_coroutines)
from tensorrt_llm.serve.metadata_server import create_metadata_server
from tensorrt_llm.serve.openai_protocol import (ChatCompletionRequest,
ChatCompletionResponse,
CompletionRequest,
@ -48,9 +51,14 @@ class OpenAIServer:
def __init__(self,
llm: LLM,
model: str):
model: str,
server_role: Optional[ServerRole] = None,
metadata_server_cfg: Optional[MetadataServerConfig] = None):
self.llm = llm
self.tokenizer = llm.tokenizer
self.metadata_server = create_metadata_server(metadata_server_cfg)
self.server_role = server_role
self.binding_addr = None # Will be set in __call__
try:
hf_tokenizer_path = llm._hf_model_dir or self.tokenizer.tokenizer.name_or_path
self.processor = AutoProcessor.from_pretrained(hf_tokenizer_path)
@ -68,8 +76,25 @@ class OpenAIServer:
@asynccontextmanager
async def lifespan(app: FastAPI):
if self.metadata_server is not None:
metadata = {
"model": self.model,
"version": VERSION,
"timestamp": datetime.now().isoformat(),
"server_role": server_role.name,
"url": self.binding_addr
}
# TODO: add more metadata
# Register with ETCD using the existing key format
self.metadata_server.put(f"trtllm/{self.llm.llm_id}", metadata)
logger.info(f"trtllm/{self.llm.llm_id} is registered")
# terminate rank0 worker
yield
if self.metadata_server is not None:
self.metadata_server.remove(f"trtllm/{self.llm.llm_id}")
logger.info(f"trtllm/{self.llm.llm_id} is unregistered")
self.llm.shutdown()
self.app = FastAPI(lifespan=lifespan)
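The registration record written under trtllm/<llm_id> looks roughly like this (values are placeholders); the routers read url, server_role (or server_type), and optionally health_check_timeout:

{
    "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "version": "<package VERSION>",
    "timestamp": "2025-06-03T20:01:44.000000",
    "server_role": "CONTEXT",
    "url": "http://localhost:8001"
}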
@ -403,11 +428,12 @@ class OpenAIServer:
# If internal executor error is raised, shutdown the server
signal.raise_signal(signal.SIGINT)
except Exception as e:
print(f"Encountered an exception: {str(e)}")
traceback.print_exc()
return self.create_error_response(str(e))
async def __call__(self, host, port):
# Store the binding address for server registration
self.binding_addr = f"http://{host}:{port}"
config = uvicorn.Config(self.app,
host=host,
port=port,

View File

@ -1,14 +1,16 @@
import asyncio
import heapq
from abc import ABC, abstractmethod
from typing import Iterable, Optional, Union
from typing import Dict, Iterable, List, Optional, Union
import aiohttp
from transformers import AutoTokenizer
from tensorrt_llm.bindings.internal.batch_manager import (BlockKey,
BlockKeyHasher)
from tensorrt_llm.llmapi.disagg_utils import RouterConfig
from tensorrt_llm.llmapi.disagg_utils import RouterConfig, ServerRole
from tensorrt_llm.logger import logger
from tensorrt_llm.serve.metadata_server import JsonDictionary
from tensorrt_llm.serve.openai_protocol import (ChatCompletionRequest,
CompletionRequest)
@ -62,6 +64,13 @@ class ServerState:
self._num_active_requests -= 1
self._num_active_tokens -= num_tokens
async def is_healthy(self) -> bool:
try:
async with self._session.get(self._server + "/health") as response:
return response.status == 200
except Exception:
return False
class KvCacheAwareServerState(ServerState):
@ -135,8 +144,24 @@ class KvCacheAwareServerState(ServerState):
class Router(ABC):
def __init__(self, servers: list[str] = None):
self._servers = servers
def __init__(self,
server_role: ServerRole,
servers: List[str] = None,
metadata_server: JsonDictionary = None):
self._servers = servers or []
self._metadata_server = metadata_server
self._server_role = server_role
self._lock = asyncio.Lock()
self._monitor_task = None
self._session = None
self._health_check_timeout = 5.0 # Default timeout in seconds
def _on_servers_updated(self, old_servers, new_servers):
"""Called when the server list changes. Override in subclasses to handle index resets.
Args:
old_servers: The previous server list
new_servers: The new server list
"""
@abstractmethod
async def get_next_server(self, request: OpenAIRequest) -> tuple[str, dict]:
@ -146,16 +171,276 @@ class Router(ABC):
async def finish_request(self, request: OpenAIRequest):
pass
async def start_server_monitoring(self, poll_interval: int = 10):
"""Start monitoring servers update from metadata service"""
if not self._metadata_server:
logger.info(
"No metadata server configured, skipping server monitoring")
return
# Create a session for health checks if it doesn't exist
if not self._session:
self._session = aiohttp.ClientSession()
logger.info(
f"Starting server monitoring for {self._server_role} servers")
self._monitor_task = asyncio.create_task(
self._monitor_servers(poll_interval))
async def stop_server_monitoring(self):
"""Stop monitoring servers update from metadata service"""
if self._monitor_task:
logger.info(
f"Stopping server monitoring for {self._server_role} servers")
self._monitor_task.cancel()
try:
await self._monitor_task
except asyncio.CancelledError:
pass
self._monitor_task = None
# Close session when stopping monitoring
await self.close_session()
async def close_session(self):
if self._session:
try:
await self._session.close()
self._session = None
logger.debug("HTTP session closed")
except Exception as e:
logger.error(f"Error closing session: {e}")
self._session = None
async def _monitor_servers(self, poll_interval: int = 10):
while True:
try:
if self._metadata_server:
# Get servers from metadata
server_key_map = await self.fetch_live_servers()
# Check health and get live servers
live_servers = await self.check_servers_health(
server_key_map)
# Filter by server role if needed
role_specific_servers = self._filter_servers_by_role(
live_servers, server_key_map)
# Use filtered servers if available
final_servers = role_specific_servers if role_specific_servers else []
# Update server list
async with self._lock:
if final_servers != self._servers:
num_old_servers = len(self._servers)
old_servers = self._servers.copy()
self._servers = final_servers
num_new_servers = len(self._servers)
# Call handler for server list changes
self._on_servers_updated(old_servers, self._servers)
logger.info(
f"Updated {self._server_role} server list: {num_old_servers} -> {num_new_servers} servers"
)
if logger.level == "debug" and self._servers:
for server in self._servers:
logger.debug(f" - {server}")
else:
logger.debug(
f"No change in {self._server_role} server list: {len(self._servers)} servers"
)
except Exception as e:
logger.error(f"Error in server monitoring: {e}")
# Wait before next poll
await asyncio.sleep(poll_interval)
def _filter_servers_by_role(self, servers, server_key_map):
"""Filter servers by role (context or generation)"""
if not self._metadata_server or not servers:
return []
filtered_servers = []
# Invert to get {url: key} for lookup
url_to_key = {url: key for key, url in server_key_map.items()}
for server_url in servers:
key = url_to_key.get(server_url)
if key:
server_metadata = self._metadata_server.get(key)
if server_metadata:
# Use either server_type or server_role field
server_type = server_metadata.get('server_type', '').lower()
if not server_type:
server_type = server_metadata.get('server_role',
'').lower()
# Check if server type matches our role
if (self._server_role == ServerRole.CONTEXT and server_type == 'context') or \
(self._server_role == ServerRole.GENERATION and server_type == 'generation'):
filtered_servers.append(server_url)
return filtered_servers
async def fetch_live_servers(self) -> Dict[str, str]:
"""Fetch all servers from metadata service and return {key: url} mapping"""
if not self._metadata_server:
# Only use static servers if no metadata server
return {server: "" for server in self._servers}
# If metadata server is available, ignore static server list entirely
server_key_map = {}
try:
# Get all keys from the metadata server
all_keys = self._metadata_server.keys()
logger.debug(f"Found {len(all_keys)} keys in metadata server")
# Filter keys that start with 'trtllm/' and extract server metadata
matching_keys = 0
for key in all_keys:
if key.startswith('trtllm/'):
matching_keys += 1
server_metadata = self._metadata_server.get(key)
if server_metadata and isinstance(
server_metadata, dict) and 'url' in server_metadata:
server_key_map[key] = server_metadata['url']
# Check if metadata includes health check timeout
if 'health_check_timeout' in server_metadata:
try:
self._health_check_timeout = float(
server_metadata['health_check_timeout'])
logger.debug(
f"Using health check timeout: {self._health_check_timeout}s"
)
except (ValueError, TypeError):
logger.warning(
f"Invalid health_check_timeout value: {server_metadata['health_check_timeout']}"
)
if server_key_map:
logger.info(
f"Using {len(server_key_map)} servers from metadata service"
)
else:
logger.warning("No servers found in metadata service")
except Exception as e:
logger.error(f"Error fetching servers from metadata service: {e}")
return server_key_map
async def check_servers_health(self,
server_key_map: Dict[str, str]) -> List[str]:
"""Check health of servers and remove dead ones from metadata service"""
live_servers = []
dead_servers = []
try:
# Check health of each server
for key, server_url in server_key_map.items():
# First attempt - no printing errors
is_healthy = await self._check_server_health(server_url,
silent=True)
# If first attempt failed, try again before declaring server dead
if not is_healthy:
# Second attempt - will print errors if it fails
is_healthy = await self._check_server_health(server_url,
silent=False)
if not is_healthy:
# Only now add to dead servers
dead_servers.append((key, server_url))
logger.warning(
f"Server {server_url} is not healthy after retry - removing"
)
else:
live_servers.append(server_url)
else:
live_servers.append(server_url)
# Remove dead servers from etcd
for key, dead_server in dead_servers:
try:
logger.info(
f"Removing dead server {dead_server} from metadata server"
)
self._metadata_server.remove(key)
except Exception as e:
logger.error(
f"Error removing dead server from metadata service: {e}"
)
except Exception as e:
logger.error(f"Error checking server health: {e}")
return live_servers if live_servers else self._servers
async def _check_server_health(self, server_url, silent=False) -> bool:
"""Check if a server is healthy by querying its health endpoint"""
if not self._session:
self._session = aiohttp.ClientSession()
try:
async with self._session.get(
f"{server_url}/health",
timeout=self._health_check_timeout) as response:
if response.status != 200:
if not silent:
logger.warning(
f"Server {server_url} is not healthy (status: {response.status})"
)
return False
return True
except Exception as e:
if not silent:
logger.warning(f"Server {server_url} is not reachable: {e}")
return False
class RoundRobinRouter(Router):
def __init__(self, servers: list[str] = None, **kwargs):
super().__init__(servers)
def __init__(self,
server_role: ServerRole,
servers: List[str] = None,
metadata_server: JsonDictionary = None,
**kwargs):
super().__init__(server_role, servers, metadata_server)
self._server_idx = 0
def _on_servers_updated(self, old_servers, new_servers):
"""Reset the index when servers are removed to prevent index out of bounds errors."""
if len(new_servers) < len(old_servers):
# Servers were removed, reset the index
self._server_idx = 0
elif self._server_idx >= len(new_servers):
# Safety check: ensure index is always within bounds
self._server_idx = 0
async def get_next_server(self, request: OpenAIRequest) -> tuple[str, dict]:
server = self._servers[self._server_idx]
self._server_idx = (self._server_idx + 1) % len(self._servers)
if not self._servers:
if self._metadata_server:
raise ValueError(
f"No {self._server_role} servers available in metadata service"
)
else:
raise ValueError(f"No {self._server_role} servers available")
async with self._lock:
# Safety check: ensure index is within bounds
if self._server_idx >= len(self._servers):
self._server_idx = 0
server = self._servers[self._server_idx]
self._server_idx = (self._server_idx + 1) % len(self._servers)
return server, {}
async def finish_request(self, request: OpenAIRequest):
@ -165,11 +450,12 @@ class RoundRobinRouter(Router):
class LoadBalancingRouter(Router):
def __init__(self,
servers: list[str] = None,
server_role: ServerRole,
servers: List[str] = None,
metadata_server: JsonDictionary = None,
use_tokens: bool = False,
**kwargs):
super().__init__(servers)
self._lock = asyncio.Lock()
super().__init__(server_role, servers, metadata_server)
# Load map between servers and their number of tokens processed
self._server_state = {}
self._server_load_heap = []
@ -180,6 +466,25 @@ class LoadBalancingRouter(Router):
self._use_tokens = use_tokens
self._init_heap()
def _on_servers_updated(self, old_servers, new_servers):
"""Rebuild the heap when the server list changes."""
# Keep the state for servers that still exist
current_state = {}
for server in new_servers:
if server in self._server_state:
# Keep existing state
current_state[server] = self._server_state[server]
else:
# Initialize new server state
current_state[server] = ServerState(server, self._use_tokens)
# Update state and rebuild heap
self._server_state = current_state
self._server_load_heap = []
for server in new_servers:
heapq.heappush(self._server_load_heap,
(self._get_server_load(server), server))
def _init_heap(self):
for server in self._servers:
self._server_state[server] = ServerState(server, self._use_tokens)
@ -187,6 +492,14 @@ class LoadBalancingRouter(Router):
(self._get_server_load(server), server))
async def get_next_server(self, request: OpenAIRequest) -> tuple[str, dict]:
if not self._servers:
if self._metadata_server:
raise ValueError(
f"No {self._server_role} servers available in metadata service"
)
else:
raise ValueError(f"No {self._server_role} servers available")
async with self._lock:
server = heapq.heappop(self._server_load_heap)[1]
await self._server_state[server].increment_load(request)
@ -220,18 +533,20 @@ def block_key_hasher(token_ids: list[int],
class KvCacheAwareRouter(Router):
def __init__(self,
server_role: Optional[ServerRole] = None,
servers: list[str] = None,
metadata_server: JsonDictionary = None,
use_tokens: bool = False,
max_batch_size: int = 64,
tokens_per_block: int = 32,
**kwargs):
super().__init__(servers)
super().__init__(server_role, servers, metadata_server)
self._lock = asyncio.Lock()
# Load map between servers and their number of tokens processed
self._server_state: dict[str, KvCacheAwareServerState] = {
server: KvCacheAwareServerState(server, use_tokens)
for server in servers
for server in servers or []
}
# Routing table to map requests to servers
@ -313,7 +628,8 @@ class KvCacheAwareRouter(Router):
def create_router(router_config: Optional[RouterConfig],
servers: list[str]) -> Router:
servers: List[str],
metadata_server: JsonDictionary = None) -> Router:
"""
Factory function to create different types of router instances.
@ -331,7 +647,8 @@ def create_router(router_config: Optional[RouterConfig],
ValueError: If an unsupported router type is provided
"""
if router_config is None:
return RoundRobinRouter(servers)
# Create a default router without server_role
return RoundRobinRouter(None, servers)
router_map = {
"round_robin": RoundRobinRouter,
@ -345,4 +662,6 @@ def create_router(router_config: Optional[RouterConfig],
raise ValueError(f"Unsupported router type: {router_type}. "
f"Supported types are: {list(router_map.keys())}")
return router_class(servers, **router_config.args)
# Pass server_role as the first argument
return router_class(router_config.server_role, servers, metadata_server,
**router_config.args)
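A short sketch of the new wiring from the factory's point of view; the metadata_server argument would come from create_metadata_server, and names follow this diff:

import asyncio

from tensorrt_llm.llmapi.disagg_utils import RouterConfig, ServerRole
from tensorrt_llm.serve.router import create_router

async def demo(metadata_server):
    cfg = RouterConfig(type="round_robin", server_role=ServerRole.GENERATION)
    router = create_router(cfg, ["localhost:8002"], metadata_server)
    await router.start_server_monitoring(poll_interval=10)
    try:
        server, _ = await router.get_next_server(request=None)  # round robin ignores the request
    finally:
        await router.stop_server_monitoring()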

View File

@ -0,0 +1,426 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import signal
import subprocess
import time
import pytest
import requests
from tensorrt_llm.logger import logger
# Configuration file paths
EXAMPLES_DIR = "examples/disaggregated"
CLIENTS_DIR = f"{EXAMPLES_DIR}/clients"
CONTEXT_CONFIG_FILE = f"{EXAMPLES_DIR}/context_extra-llm-api-config.yml"
GENERATION_CONFIG_FILE = f"{EXAMPLES_DIR}/gen_extra-llm-api-config.yml"
ETCD_CONFIG_FILE = f"{EXAMPLES_DIR}/etcd_config.yaml"
DISAGG_CONFIG_FILE = f"{EXAMPLES_DIR}/disagg_config.yaml"
CLIENT_SCRIPT_FILE = f"{CLIENTS_DIR}/disagg_client.py"
PROMPTS_FILE = f"{CLIENTS_DIR}/prompts.json"
def kill_automated_disaggregated_processes():
"""Kill any existing automated disaggregated processes."""
try:
subprocess.run(['pkill', '-9', '-f', 'trtllm-serve'], check=False)
except Exception:
pass
def cleanup_automated_output_files():
"""Clean up output files from previous runs."""
for file in [
'output.json', 'output_streaming.json', 'output_workers.log',
'output_disagg.log'
]:
try:
os.remove(file)
except FileNotFoundError:
pass
def start_context_server(config,
gpu_id: int,
port: int,
env=None) -> subprocess.Popen:
"""Start a context server on specified GPU and port."""
cmd = [
"trtllm-serve", config['model_path'], "--host", "localhost", "--port",
str(port), "--backend", "pytorch", "--extra_llm_api_options",
f"./{CONTEXT_CONFIG_FILE}", "--metadata_server_config_file",
ETCD_CONFIG_FILE, "--server_role", "CONTEXT"
]
server_env = env.copy() if env else os.environ.copy()
server_env["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
server_env["TRTLLM_USE_UCX_KVCACHE"] = "1"
logger.info(f"Starting CONTEXT server on GPU {gpu_id} (port {port})...")
process = subprocess.Popen(cmd,
env=server_env,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
bufsize=1)
return process
def start_generation_server(config,
gpu_id: int,
port: int,
env=None) -> subprocess.Popen:
"""Start a generation server on specified GPU and port."""
cmd = [
"trtllm-serve", config['model_path'], "--host", "localhost", "--port",
str(port), "--backend", "pytorch", "--extra_llm_api_options",
f"./{GENERATION_CONFIG_FILE}", "--metadata_server_config_file",
ETCD_CONFIG_FILE, "--server_role", "GENERATION"
]
server_env = env.copy() if env else os.environ.copy()
server_env["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
server_env["TRTLLM_USE_UCX_KVCACHE"] = "1"
logger.info(f"Starting GENERATION server on GPU {gpu_id} (port {port})...")
process = subprocess.Popen(cmd,
env=server_env,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
bufsize=1)
return process
def start_disaggregated_service(config, env=None) -> subprocess.Popen:
"""Launch the disaggregated service."""
cmd = [
"trtllm-serve", "disaggregated", "-c", DISAGG_CONFIG_FILE, "-m",
ETCD_CONFIG_FILE
]
logger.info("Launching disaggregated service...")
process = subprocess.Popen(cmd,
env=env,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
bufsize=1)
return process
def wait_for_server_health(port: int, timeout: int = 120) -> bool:
"""Wait for server to be healthy by checking /health endpoint."""
url = f"http://localhost:{port}/health"
start_time = time.time()
logger.info(f"Waiting for server on port {port} to be healthy...")
while time.time() - start_time < timeout:
try:
response = requests.get(url, timeout=5)
if response.status_code == 200:
logger.info(f"Server on port {port} is healthy")
return True
except requests.RequestException:
pass
time.sleep(2)
logger.error(f"Timed out waiting for server on port {port}")
return False
def run_client_test(config, env=None) -> bool:
"""Run the disaggregated client test."""
cmd = [
"python3", f"./{CLIENT_SCRIPT_FILE}", "-c", DISAGG_CONFIG_FILE, "-p",
f"./{PROMPTS_FILE}"
]
logger.info("Running disaggregated client test...")
result = subprocess.run(cmd,
env=env,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True)
if result.returncode == 0:
logger.info("Client test succeeded")
logger.info(f"Client output: {result.stdout}")
return True
else:
logger.error(f"Client test failed with return code {result.returncode}")
logger.error(f"Error output: {result.stderr}")
logger.error(f"Standard output: {result.stdout}")
return False
def kill_server_by_port(port: int) -> bool:
"""Find and kill a process by port using lsof."""
try:
# Find PID using port
cmd = ["lsof", "-t", f"-i:{port}"]
result = subprocess.run(cmd, stdout=subprocess.PIPE, text=True)
if result.stdout.strip():
pid = int(result.stdout.strip())
os.kill(pid, signal.SIGKILL)
logger.info(f"Killed process {pid} on port {port}")
return True
else:
logger.warning(f"No process found on port {port}")
return False
except Exception as e:
logger.error(f"Error killing process on port {port}: {e}")
return False
def cleanup_processes(processes):
"""Kill all started processes."""
logger.info("Cleaning up all processes...")
for name, process in processes.items():
if process.poll() is None: # Still running
logger.info(f"Terminating {name} (PID: {process.pid})")
try:
process.terminate()
process.wait(timeout=5)
except (subprocess.TimeoutExpired, ProcessLookupError):
logger.warning(f"Force killing {name} (PID: {process.pid})")
try:
process.kill()
except ProcessLookupError:
pass
def start_etcd_server(working_dir, env=None) -> subprocess.Popen:
"""Start etcd server."""
cmd = ["etcd"]
logger.info("Starting etcd server...")
process = subprocess.Popen(cmd,
env=env,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
bufsize=1,
cwd=working_dir)
return process
def cleanup_etcd_data(env=None):
"""Clean up etcd data using etcdctl."""
cmd = ["etcdctl", "del", "--prefix", "trtllm/"]
logger.info("Cleaning etcd data...")
result = subprocess.run(cmd,
env=env,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True)
if result.returncode == 0:
logger.info("Successfully cleaned etcd data")
else:
logger.warning(f"Failed to clean etcd data: {result.stderr}")
def create_config_files(config):
"""Create necessary configuration files"""
# Create context config file
context_config_content = """pytorch_backend_config:
disable_overlap_scheduler: True
cache_transceiver_config:
max_num_tokens: 2048"""
with open(CONTEXT_CONFIG_FILE, 'w') as file:
file.write(context_config_content)
# Create generation config file
generation_config_content = """cache_transceiver_config:
max_num_tokens: 2048"""
with open(GENERATION_CONFIG_FILE, 'w') as file:
file.write(generation_config_content)
# Create etcd config file
etcd_config_content = """server_type: "etcd"
hostname: "localhost"
port: 2379
health_check_timeout: 5.0"""
with open(ETCD_CONFIG_FILE, 'w') as file:
file.write(etcd_config_content)
disagg_config_content = """hostname: localhost
port: 8000
backend: pytorch
context_servers:
num_instances: 1
urls:
- "localhost:8001"
generation_servers:
num_instances: 1
urls:
- "localhost:8002"
"""
with open(DISAGG_CONFIG_FILE, 'w') as file:
file.write(disagg_config_content)
return True
def run_automated_disaggregated_test(example_dir, env=None, cwd=None):
"""Run automated disaggregated test with given configuration."""
kill_automated_disaggregated_processes()
cleanup_automated_output_files()
config = {"model_path": "TinyLlama/TinyLlama-1.1B-Chat-v1.0"}
# Create configuration files
create_config_files(config)
processes = {}
try:
# Step 1: Start etcd server
logger.info("Step 1: Starting etcd server...")
processes["etcd"] = start_etcd_server(".", env=env)
time.sleep(5) # Give etcd time to start
# Step 2: Clean etcd data
logger.info("Step 2: Cleaning etcd data...")
cleanup_etcd_data(env=env)
# Step 3: Start context server on GPU 0 (port 8001)
logger.info("Step 3: Starting context server on GPU 0 (port 8001)...")
processes["context_8001"] = start_context_server(config,
gpu_id=0,
port=8001,
env=env)
# Step 4: Start generation server on GPU 1 (port 8002)
logger.info(
"Step 4: Starting generation server on GPU 1 (port 8002)...")
processes["generation_8002"] = start_generation_server(config,
gpu_id=1,
port=8002,
env=env)
# Step 5: Wait until the context and generation servers are ready
logger.info(
"Step 5: Waiting for context and generation servers to be ready...")
if not wait_for_server_health(port=8001):
logger.error("Context server on port 8001 failed to start")
return False
if not wait_for_server_health(port=8002):
logger.error("Generation server on port 8002 failed to start")
return False
# Step 6: Start disaggregated service
logger.info("Step 6: Starting disaggregated service...")
processes["disagg_service"] = start_disaggregated_service(config,
env=env)
# Step 7: Wait for disaggregated service and run first client test
logger.info(
"Step 7: Waiting for disaggregated service and running first client test..."
)
if not wait_for_server_health(port=8000):
logger.error("Disaggregated service failed to start")
return False
first_test_success = run_client_test(config, env=env)
if not first_test_success:
logger.error("First client test failed")
return False
# Step 8: Start second context server on GPU 2 (port 8003)
logger.info(
"Step 8: Starting second context server on GPU 2 (port 8003)...")
processes["context_8003"] = start_context_server(config,
gpu_id=2,
port=8003,
env=env)
# Step 9: Wait until ready, then wait 10 more seconds and run the second client test
logger.info(
"Step 9: Waiting for second context server and running second client test..."
)
if not wait_for_server_health(port=8003):
logger.error("Second context server on port 8003 failed to start")
return False
logger.info("Waiting additional 10 seconds for system stabilization...")
time.sleep(10)
second_test_success = run_client_test(config, env=env)
if not second_test_success:
logger.error("Second client test failed")
return False
# Step 10: Kill 8001 process (first context server)
logger.info("Step 10: Killing first context server (port 8001)...")
if "context_8001" in processes:
process = processes["context_8001"]
if process.poll() is None:
process.terminate()
try:
process.wait(timeout=10)
except subprocess.TimeoutExpired:
process.kill()
kill_server_by_port(8001)
# Step 11: Wait a few seconds and run final client test
logger.info(
"Step 11: Waiting a few seconds and running final client test...")
time.sleep(5)
final_test_success = run_client_test(config, env=env)
if not final_test_success:
logger.error("Final client test failed")
return False
logger.info("✅ All automated disaggregated tests passed successfully!")
return True
except Exception as e:
logger.exception(f"Error during automated test: {e}")
return False
finally:
cleanup_processes(processes)
kill_automated_disaggregated_processes()
@pytest.mark.parametrize("llama_model_root", ['TinyLlama-1.1B-Chat-v1.0'],
indirect=True)
def test_automated_disaggregated_complete(disaggregated_test_root,
disaggregated_example_root, llm_venv,
llama_model_root):
src_dst_dict = {
llama_model_root:
f"{llm_venv.get_working_directory()}/TinyLlama/TinyLlama-1.1B-Chat-v1.0",
}
for src, dst in src_dst_dict.items():
if not os.path.islink(dst):
os.makedirs(os.path.dirname(dst), exist_ok=True)
os.symlink(src, dst, target_is_directory=True)
success = run_automated_disaggregated_test(
disaggregated_example_root,
env=llm_venv._new_env,
cwd=llm_venv.get_working_directory())
assert success, "Automated disaggregated test failed"

View File

@ -16,7 +16,7 @@ from tensorrt_llm import logger
from tensorrt_llm.serve.openai_protocol import (CompletionRequest,
DisaggregatedParams)
from tensorrt_llm.serve.router import (KvCacheAwareRouter,
KvCacheAwareServerState,
KvCacheAwareServerState, ServerRole,
block_key_hasher)
MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
@ -350,8 +350,10 @@ class KvCacheAwareRouterTester(BasicWorkerTester):
server_start_timeout_secs: int = 180):
super().__init__(ctx_servers, gen_servers, req_timeout_secs,
server_start_timeout_secs)
self.ctx_router = KvCacheAwareRouter(ctx_servers)
self.gen_router = KvCacheAwareRouter(gen_servers)
self.ctx_router = KvCacheAwareRouter(server_role=ServerRole.CONTEXT,
servers=ctx_servers)
self.gen_router = KvCacheAwareRouter(server_role=ServerRole.GENERATION,
servers=gen_servers)
async def multi_round_request(self,
session: aiohttp.ClientSession,

View File

@ -140,3 +140,6 @@ properties:
workspace:
annotation: pathlib.Path
default: inspect._empty
llm_id:
annotation: str
default: inspect._empty

View File

@ -0,0 +1,113 @@
import unittest
import etcd3
from tensorrt_llm.serve.metadata_server import EtcdDictionary
class TestEtcdDictionary(unittest.TestCase):
def setUp(self):
# Setup etcd connection parameters
self.host = "localhost"
self.port = 2379
# Create a clean etcd client for test setup/teardown
self.cleanup_client = etcd3.client(host=self.host, port=self.port)
# Create the dictionary under test
self.etcd_dict = EtcdDictionary(host=self.host, port=self.port)
# Clean up any existing test keys before each test
self._cleanup_test_keys()
def tearDown(self):
# Clean up test keys after each test
self._cleanup_test_keys()
def _cleanup_test_keys(self):
# Helper method to remove test keys
test_keys = [
"trtllm/1/test_key1", "trtllm/1/test_key2", "trtllm/2/test_key3"
]
for key in test_keys:
self.cleanup_client.delete(key)
def test_put_and_get(self):
# Test putting and getting a value
test_key = "trtllm/1/test_key1"
test_value = "value1"
# Put the value
self.etcd_dict.put(test_key, test_value)
# Get the value
value, _ = self.etcd_dict.get(test_key)
# Assert
self.assertEqual(value.decode('utf-8'), test_value)
def test_remove(self):
# Test removing a key
test_key = "trtllm/1/test_key2"
test_value = "value2"
# Setup: Put a value first
self.etcd_dict.put(test_key, test_value)
# Remove the key
self.etcd_dict.remove(test_key)
# Verify key is removed by trying to get it
result = self.cleanup_client.get(test_key)
self.assertIsNone(
result[0]) # etcd3 returns (None, None) when key doesn't exist
def test_keys(self):
# Test listing all keys
test_data = {
"trtllm/1/test_key1": "value1",
"trtllm/1/test_key2": "value2",
"trtllm/2/test_key3": "value3"
}
prefix_data = {"trtllm/1": "value1", "trtllm/2": "value2"}
# Setup: Put multiple values
for key, value in test_data.items():
self.etcd_dict.put(key, value)
# Get all keys
keys = self.etcd_dict.keys()
# Assert all test keys are present
prefix_keys = set(prefix_data.keys())
extract_keys = set(keys)
self.assertEqual(prefix_keys, extract_keys)
def test_get_nonexistent_key(self):
# Test getting a key that doesn't exist
result, _ = self.etcd_dict.get("nonexistent_key")
self.assertIsNone(result)
def test_put_update_existing(self):
# Test updating an existing key
test_key = "trtllm/1/test_key1"
initial_value = "initial_value"
updated_value = "updated_value"
# Put initial value
self.etcd_dict.put(test_key, initial_value)
# Update value
self.etcd_dict.put(test_key, updated_value)
# Get updated value
value, _ = self.etcd_dict.get(test_key)
# Assert
self.assertEqual(value.decode('utf-8'), updated_value)
if __name__ == '__main__':
unittest.main()

View File

@ -1,8 +1,13 @@
import asyncio
import copy
import threading
import time
from unittest import mock
import pytest
from tensorrt_llm.llmapi.disagg_utils import RouterConfig, ServerRole
from tensorrt_llm.logger import logger
from tensorrt_llm.serve.openai_protocol import (ChatCompletionRequest,
CompletionRequest,
DisaggregatedParams)
@ -10,6 +15,40 @@ from tensorrt_llm.serve.router import (KvCacheAwareRouter, LoadBalancingRouter,
RoundRobinRouter, create_router)
# Mock class for metadata server
class MockMetadataServer:
"""Mock metadata server for testing router interactions"""
def __init__(self):
self.servers = {}
self.lock = threading.Lock()
def get(self, key):
with self.lock:
return self.servers.get(key)
def put(self, key, value):
with self.lock:
self.servers[key] = value
return True
def remove(self, key):
with self.lock:
if key in self.servers:
del self.servers[key]
return True
return False
def add_server(self, key, url):
with self.lock:
self.servers[key] = url
return True
def keys(self, prefix=""):
with self.lock:
return [k for k in self.servers.keys() if k.startswith(prefix)]
@pytest.fixture
def servers():
return ["server1", "server2", "server3"]
@ -301,3 +340,245 @@ def test_create_router(servers):
with pytest.raises(ValueError):
create_router(RouterConfig(type="unsupported_router"), servers)
@pytest.fixture
def mock_metadata_server():
return MockMetadataServer()
@pytest.mark.slow
def test_fetch_live_servers_context(mock_metadata_server):
"""Test fetching live context servers"""
# Create router with mock metadata server
router = RoundRobinRouter(server_role=ServerRole.CONTEXT,
                          metadata_server=mock_metadata_server)
# Start server monitoring with a shorter poll interval for testing
# but still long enough to verify the actual behavior
poll_interval = 10 # seconds
asyncio.run(router.start_server_monitoring(poll_interval=poll_interval))
try:
# Initial check - should be no servers
servers = asyncio.run(router.fetch_live_servers())
assert len(servers) == 0, "Should have no servers initially"
# Add a server
server_key = "servers/context/server1"
server_url = "http://localhost:8001"
mock_metadata_server.add_server(server_key, {"url": server_url})
# Wait for the polling interval to pass (add 50% buffer)
wait_time = poll_interval * 1.5
logger.info(f"Waiting {wait_time} seconds for server to be detected...")
time.sleep(wait_time)
# Fetch servers again
servers = asyncio.run(router.fetch_live_servers())
assert len(servers) == 1, "Should have one server after adding and waiting"
assert server_url in servers.values(), "Server URL should match what was added"
# Add another server
server_key2 = "servers/context/server2"
server_url2 = "http://localhost:8002"
mock_metadata_server.add_server(server_key2, {"url": server_url2})
# Wait for the polling interval again
logger.info(
f"Waiting {wait_time} seconds for second server to be detected...")
time.sleep(wait_time)
# Fetch servers again
servers = asyncio.run(router.fetch_live_servers())
assert len(servers) == 2, "Should have two servers after adding second one and waiting"
assert server_url in servers.values(), "First server should still be present"
assert server_url2 in servers.values(), "Second server should be present"
# Remove a server
mock_metadata_server.remove(server_key)
# Wait for the polling interval again
logger.info(
f"Waiting {wait_time} seconds for server removal to be detected...")
time.sleep(wait_time)
# Fetch servers again
servers = asyncio.run(router.fetch_live_servers())
assert len(servers) == 1, "Should have one server after removing one and waiting"
assert server_url2 in servers.values(), "Remaining server should be the second one"
finally:
# Clean up
asyncio.run(router.stop_server_monitoring())
@pytest.mark.slow
def test_fetch_live_servers_with_delay(mock_metadata_server):
"""Test fetching live servers with the actual polling delay"""
# Create router with mock metadata server
poll_interval = 5 # seconds
router = RoundRobinRouter(server_role=ServerRole.CONTEXT,
                          metadata_server=mock_metadata_server)
# Start server monitoring with shorter interval for testing
asyncio.run(router.start_server_monitoring(poll_interval=poll_interval))
try:
# Initial check - should be no servers
servers = asyncio.run(router.fetch_live_servers())
assert len(servers) == 0, "Should have no servers initially"
# Add a server
server_key = "servers/context/server1"
server_url = "http://localhost:8001"
mock_metadata_server.add_server(server_key, {"url": server_url})
# Wait for a bit less than the polling interval - should still have no servers
short_wait = poll_interval * 0.4
logger.info(
f"Waiting {short_wait} seconds (less than polling interval)...")
time.sleep(short_wait)
# fetch_live_servers queries the metadata service directly, so the
# new server is already visible before a polling interval elapses
servers = asyncio.run(router.fetch_live_servers())
assert len(servers) == 1, "Direct fetch sees the server before the polling interval"
# Wait for the polling interval to pass
remaining_wait = poll_interval * 1.2
logger.info(
f"Waiting additional {remaining_wait} seconds for server to be detected..."
)
time.sleep(remaining_wait)
# Now should have the server
servers = asyncio.run(router.fetch_live_servers())
assert len(servers) == 1, "Should have one server after polling interval"
assert server_url in servers.values(), "Server URL should match what was added"
# Remove the server
mock_metadata_server.remove(server_key)
# Wait for polling interval to pass
wait_time = poll_interval * 1.5
logger.info(
f"Waiting {wait_time} seconds for server removal to be detected...")
time.sleep(wait_time)
# Should now be empty again
servers = asyncio.run(router.fetch_live_servers())
assert len(servers) == 0, "Should have no servers after removal and waiting"
finally:
# Clean up
asyncio.run(router.stop_server_monitoring())
@pytest.mark.slow
def test_server_health_check(mock_metadata_server):
"""Test that unhealthy servers are filtered out"""
# Create router with mock metadata server
poll_interval = 5 # seconds
router = RoundRobinRouter(server_role=ServerRole.CONTEXT,
                          metadata_server=mock_metadata_server)
# Start server monitoring
asyncio.run(router.start_server_monitoring(poll_interval=poll_interval))
try:
# Add two servers
server_key1 = "servers/context/server1"
server_url1 = "http://localhost:8001"
mock_metadata_server.add_server(server_key1, {"url": server_url1})
server_key2 = "servers/context/server2"
server_url2 = "http://localhost:8002"
mock_metadata_server.add_server(server_key2, {"url": server_url2})
# Wait for the polling interval to pass
wait_time = poll_interval * 1.5
logger.info(
f"Waiting {wait_time} seconds for servers to be detected...")
time.sleep(wait_time)
# Mock the internal health check so only the second server looks healthy
async def fake_health(url, silent=False):
    return url == server_url2

with mock.patch.object(router, '_check_server_health',
                       side_effect=fake_health):
    # Resolve the key -> url map, then keep only healthy servers
    server_key_map = asyncio.run(router.fetch_live_servers())
    servers = asyncio.run(router.check_servers_health(server_key_map))
    assert len(servers) == 1, "Should have one healthy server"
    assert servers[0] == server_url2, "Only healthy server should be returned"
finally:
# Clean up
asyncio.run(router.stop_server_monitoring())
@pytest.mark.slow
def test_load_balancing_router_fetch_servers(mock_metadata_server):
"""Test that LoadBalancingRouter fetches servers correctly"""
# Create router with mock metadata server
poll_interval = 10 # seconds
router = LoadBalancingRouter(server_role=ServerRole.CONTEXT,
                             metadata_server=mock_metadata_server)
# Start server monitoring
asyncio.run(router.start_server_monitoring(poll_interval=poll_interval))
try:
# Add two servers
server_key1 = "servers/context/server1"
server_url1 = "http://localhost:8001"
mock_metadata_server.add_server(server_key1, {"url": server_url1})
server_key2 = "servers/context/server2"
server_url2 = "http://localhost:8002"
mock_metadata_server.add_server(server_key2, {"url": server_url2})
# Wait for the polling interval to pass
wait_time = poll_interval * 1.5
logger.info(
f"Waiting {wait_time} seconds for servers to be detected...")
time.sleep(wait_time)
# Fetch servers
servers = asyncio.run(router.fetch_live_servers())
assert len(servers) == 2, "Should have two servers after waiting"
# Remove all servers
mock_metadata_server.remove(server_key1)
mock_metadata_server.remove(server_key2)
# Wait for the polling interval to pass
logger.info(
f"Waiting {wait_time} seconds for server removals to be detected..."
)
time.sleep(wait_time)
# Test handling of no servers
servers = asyncio.run(router.fetch_live_servers())
assert len(servers) == 0, "Should have no servers after removing all and waiting"
# Test get_next_server with no servers should raise ValueError
with pytest.raises(ValueError):
    asyncio.run(router.get_next_server(None))
finally:
# Clean up
asyncio.run(router.stop_server_monitoring())