feat: Add integration of etcd (#3738)
Signed-off-by: Shunkang <182541032+Shunkangz@users.noreply.github.com>
Signed-off-by: BatshevaBlack <132911331+BatshevaBlack@users.noreply.github.com>
Co-authored-by: Shunkang <182541032+Shunkangz@users.noreply.github.com>
Co-authored-by: Batsheva Black <bblack@login-eos01.eos.clusters.nvidia.com>
Co-authored-by: BatshevaBlack <132911331+BatshevaBlack@users.noreply.github.com>
This commit is contained in:
parent
3fe4a1842a
commit
ae9a6cf24f
@@ -55,3 +55,4 @@ nvtx
 matplotlib # FIXME: this is added to make nvtx happy
 meson
 ninja
+etcd3
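(If you are experimenting outside the packaged requirements, the new dependency can presumably be installed standalone with `pip install etcd3`; that the PyPI distribution name matches the requirement above is an assumption, not something this diff states.)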
@@ -18,7 +18,9 @@ from tensorrt_llm.llmapi import (LLM, BuildConfig, CapacitySchedulerPolicy,
                                  DynamicBatchConfig, KvCacheConfig,
                                  SchedulerConfig)
 from tensorrt_llm.llmapi.disagg_utils import (CtxGenServerConfig,
-                                              parse_disagg_config_file)
+                                              MetadataServerConfig, ServerRole,
+                                              parse_disagg_config_file,
+                                              parse_metadata_server_config_file)
 from tensorrt_llm.llmapi.llm_utils import update_llm_args_with_extra_dict
 from tensorrt_llm.llmapi.mpi_session import find_free_port
 from tensorrt_llm.llmapi.reasoning_parser import ReasoningParserFactory
@@ -127,7 +129,12 @@ def get_llm_args(model: str,
     return llm_args, llm_args_extra_dict


-def launch_server(host: str, port: int, llm_args: dict):
+def launch_server(host: str,
+                  port: int,
+                  llm_args: dict,
+                  metadata_server_cfg: Optional[MetadataServerConfig] = None,
+                  server_role: Optional[ServerRole] = None):

     backend = llm_args["backend"]
     model = llm_args["model"]
@@ -136,7 +143,10 @@ def launch_server(host: str, port: int, llm_args: dict):
     else:
         llm = LLM(**llm_args)

-    server = OpenAIServer(llm=llm, model=model)
+    server = OpenAIServer(llm=llm,
+                          model=model,
+                          server_role=server_role,
+                          metadata_server_cfg=metadata_server_cfg)

     asyncio.run(server(host, port))

@@ -228,6 +238,16 @@ def launch_server(host: str, port: int, llm_args: dict):
     default=None,
     help="[Experimental] Specify the parser for reasoning models.",
 )
+@click.option("--metadata_server_config_file",
+              type=str,
+              default=None,
+              help="Path to metadata server config file")
+@click.option(
+    "--server_role",
+    type=str,
+    default=None,
+    help="Server role. Specify this value only if running in disaggregated mode."
+)
 def serve(model: str, tokenizer: Optional[str], host: str, port: int,
           log_level: str, backend: str, max_beam_width: int,
           max_batch_size: int, max_num_tokens: int, max_seq_len: int,
@@ -235,8 +255,9 @@ def serve(model: str, tokenizer: Optional[str], host: str, port: int,
           cluster_size: Optional[int], gpus_per_node: Optional[int],
           kv_cache_free_gpu_memory_fraction: float,
           num_postprocess_workers: int, trust_remote_code: bool,
-          extra_llm_api_options: Optional[str],
-          reasoning_parser: Optional[str]):
+          extra_llm_api_options: Optional[str], reasoning_parser: Optional[str],
+          metadata_server_config_file: Optional[str],
+          server_role: Optional[str]):
     """Running an OpenAI API compatible server

     MODEL: model name | HF checkpoint path | TensorRT engine path
@@ -267,7 +288,16 @@ def serve(model: str, tokenizer: Optional[str], host: str, port: int,
             llm_args_extra_dict = yaml.safe_load(f)
         llm_args = update_llm_args_with_extra_dict(llm_args, llm_args_extra_dict)

-    launch_server(host, port, llm_args)
+    metadata_server_cfg = parse_metadata_server_config_file(
+        metadata_server_config_file)
+
+    if metadata_server_cfg is not None:
+        try:
+            server_role = ServerRole[server_role.upper()]
+        except ValueError:
+            raise ValueError(f"Invalid server role: {server_role}. " \
+                f"Must be one of: {', '.join([role.name for role in ServerRole])}")
+    launch_server(host, port, llm_args, metadata_server_cfg, server_role)


 def get_ctx_gen_server_urls(
@@ -289,6 +319,11 @@ def get_ctx_gen_server_urls(
               type=str,
               default=None,
               help="Specific option for disaggregated mode.")
+@click.option("-m",
+              "--metadata_server_config_file",
+              type=str,
+              default=None,
+              help="Path to metadata server config file")
 @click.option("-t",
               "--server_start_timeout",
               type=int,
@@ -299,15 +334,27 @@ def get_ctx_gen_server_urls(
               type=int,
               default=180,
               help="Request timeout")
-def disaggregated(config_file: Optional[str], server_start_timeout: int,
-                  request_timeout: int):
+@click.option("-l",
+              '--log_level',
+              type=click.Choice(severity_map.keys()),
+              default='info',
+              help="The logging level.")
+def disaggregated(config_file: Optional[str],
+                  metadata_server_config_file: Optional[str],
+                  server_start_timeout: int, request_timeout: int,
+                  log_level: str):
     """Running server in disaggregated mode"""

+    logger.set_level(log_level)
+
     disagg_cfg = parse_disagg_config_file(config_file)

     ctx_server_urls, gen_server_urls = get_ctx_gen_server_urls(
         disagg_cfg.server_configs)

+    metadata_server_cfg = parse_metadata_server_config_file(
+        metadata_server_config_file)
+
     server = OpenAIDisaggServer(
         ctx_servers=ctx_server_urls,
         gen_servers=gen_server_urls,
@@ -315,7 +362,8 @@ def disaggregated(config_file: Optional[str], server_start_timeout: int,
         server_start_timeout_secs=server_start_timeout,
         ctx_router_config=disagg_cfg.ctx_router_config,
         gen_router_config=disagg_cfg.gen_router_config,
-        conditional_disagg_config=disagg_cfg.conditional_disagg_config)
+        conditional_disagg_config=disagg_cfg.conditional_disagg_config,
+        metadata_server_cfg=metadata_server_cfg)

     asyncio.run(server(disagg_cfg.hostname, disagg_cfg.port))

@@ -1,5 +1,6 @@
 import logging
 from dataclasses import dataclass, field
+from enum import Enum
 from typing import Any, List, Literal, Optional, Tuple

 import yaml
@@ -15,6 +16,11 @@ __all__ = [
 ]


+class ServerRole(Enum):
+    CONTEXT = 0
+    GENERATION = 1
+
+
 @dataclass
 class CtxGenServerConfig():
     type: Literal['ctx', 'gen']
@@ -28,6 +34,7 @@ class CtxGenServerConfig():
 class RouterConfig():
     type: str = "round_robin"
     args: dict = field(default_factory=dict)
+    server_role: ServerRole = None


 @dataclass
@@ -45,6 +52,14 @@ class DisaggServerConfig():
     conditional_disagg_config: Optional[ConditionalDisaggConfig] = None


+@dataclass
+class MetadataServerConfig():
+    server_type: Literal['etcd']
+    hostname: str = "localhost"
+    port: int = 2379
+    health_check_timeout: float = 5.0
+
+
 def parse_disagg_config_file(yaml_config_file: str):

     with open(yaml_config_file, 'r') as file:
@@ -85,6 +100,9 @@ def extract_disagg_cfg(hostname: str = 'localhost',

     ctx_router_config = extract_router_config(context_servers)
     gen_router_config = extract_router_config(generation_servers)
+    ctx_router_config.server_role = ServerRole.CONTEXT
+    gen_router_config.server_role = ServerRole.GENERATION

     conditional_disagg_config = ConditionalDisaggConfig(
         **conditional_disagg_config) if conditional_disagg_config else None

@@ -220,3 +238,14 @@ def split_world_comm(
     )

     return is_leader, instance_idx, sub_comm
+
+
+def parse_metadata_server_config_file(
+        metadata_server_config_file: Optional[str]
+) -> Optional[MetadataServerConfig]:
+    if metadata_server_config_file is None:
+        return None
+
+    with open(metadata_server_config_file, 'r') as file:
+        config = yaml.safe_load(file)
+        return MetadataServerConfig(**config)
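For reference, a minimal sketch of how the new parser is meant to be fed; the field names come from the MetadataServerConfig dataclass above, and the same YAML shape appears in the integration test later in this diff (file name and values here are illustrative):

# Hypothetical usage sketch of parse_metadata_server_config_file.
# Assumes a file "etcd_config.yaml" with the illustrative contents:
#
#   server_type: "etcd"
#   hostname: "localhost"
#   port: 2379
#   health_check_timeout: 5.0
from tensorrt_llm.llmapi.disagg_utils import parse_metadata_server_config_file

cfg = parse_metadata_server_config_file("etcd_config.yaml")
assert cfg is not None and cfg.server_type == "etcd" and cfg.port == 2379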
@@ -2,7 +2,9 @@ import atexit
 import json
 import os
 import shutil
+import socket
 import tempfile
+import time
 import weakref
 from pathlib import Path
 from typing import Any, List, Literal, Optional, Sequence, Union
@@ -86,6 +88,7 @@ LLM_DOCSTRING = LLMARGS_EXPLICIT_DOCSTRING + """
     Attributes:
         tokenizer (tensorrt_llm.llmapi.tokenizer.TokenizerBase, optional): The tokenizer loaded by LLM instance, if any.
         workspace (pathlib.Path): The directory to store intermediate files.
+        llm_id (str): The unique ID of the LLM instance.
 """


@@ -110,6 +113,7 @@ class LLM:
                  **kwargs: Any) -> None:

         self._executor_cls = kwargs.pop("executor_cls", GenerationExecutor)
+        self._llm_id = None

         try:
             llm_args_cls = TorchLlmArgs if kwargs.get(
@@ -186,6 +190,16 @@ class LLM:
     def workspace(self) -> Path:
         return Path(self._workspace.name) if self._on_trt_backend else None

+    @property
+    def llm_id(self) -> str:
+        if self._llm_id is None:
+            hostname = socket.gethostname()
+            pid = os.getpid()
+            timestamp = int(time.time() * 1000)
+            self._llm_id = f"{hostname}-{pid}-{timestamp}"
+
+        return self._llm_id
+
     def generate(
         self,
         inputs: Union[PromptInputs, Sequence[PromptInputs]],
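The llm_id produced by the new property has the form hostname-pid-milliseconds, e.g. "node01-12345-1713540000000" (values illustrative); later in this diff it becomes the etcd registration key "trtllm/<llm_id>".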
tensorrt_llm/serve/metadata_server.py (new file, 95 lines)
@@ -0,0 +1,95 @@
import json
from abc import ABC, abstractmethod
from typing import Optional

from tensorrt_llm.llmapi.disagg_utils import MetadataServerConfig
from tensorrt_llm.logger import logger

try:
    import etcd3
except Exception as e:
    logger.warning(f"etcd3 is not installed correctly: {e}")


class RemoteDictionary(ABC):

    @abstractmethod
    def get(self, key: str) -> str:
        pass

    @abstractmethod
    def put(self, key: str, value: str):
        pass

    @abstractmethod
    def remove(self, key: str):
        pass

    @abstractmethod
    def keys(self) -> list[str]:
        pass


class EtcdDictionary(RemoteDictionary):

    def __init__(self, host: str, port: int):
        self._client = etcd3.client(host, port)

    def get(self, key: str) -> str:
        return self._client.get(key)

    def put(self, key: str, value: str):
        self._client.put(key, value)

    def remove(self, key: str):
        self._client.delete(key)

    def keys(self) -> list[str]:
        # TODO: Confirm the final save key format
        # This implementation assumes that key is in the
        # format of "trtllm/executor_name/key"
        unique_keys = set()
        for _, metadata in self._client.get_all():
            key = metadata.key.decode('utf-8')
            sub_keys = key.split('/')
            if len(sub_keys) >= 2:
                top_prefix = "/".join(sub_keys[:2])
                unique_keys.add(top_prefix)
        return list(unique_keys)


class JsonDictionary:

    def __init__(self, dict: RemoteDictionary):
        self._dict = dict

    def get(self, key: str) -> str:
        bytes_data, _ = self._dict.get(key)
        return json.loads(bytes_data.decode('utf-8'))

    def put(self, key: str, value: str):
        self._dict.put(key, json.dumps(value))

    def remove(self, key: str):
        self._dict.remove(key)

    def keys(self) -> list[str]:
        return self._dict.keys()


def create_metadata_server(
    metadata_server_cfg: Optional[MetadataServerConfig]
) -> Optional[JsonDictionary]:

    if metadata_server_cfg is None:
        return None

    if metadata_server_cfg.server_type == 'etcd':
        dict = EtcdDictionary(host=metadata_server_cfg.hostname,
                              port=metadata_server_cfg.port)
    else:
        raise ValueError(
            f"Unsupported metadata server type: {metadata_server_cfg.server_type}"
        )

    return JsonDictionary(dict)
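A minimal usage sketch of the wrapper above, assuming an etcd instance is reachable at the default localhost:2379; the key and value are illustrative. Note that JsonDictionary JSON-encodes values on put() and decodes them on get():

from tensorrt_llm.llmapi.disagg_utils import MetadataServerConfig
from tensorrt_llm.serve.metadata_server import create_metadata_server

# server_type is the only required field; hostname/port default to localhost:2379
server = create_metadata_server(MetadataServerConfig(server_type='etcd'))

server.put("trtllm/example-executor", {"url": "http://localhost:8001"})
print(server.get("trtllm/example-executor"))  # -> {'url': 'http://localhost:8001'}
server.remove("trtllm/example-executor")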
@@ -18,7 +18,9 @@ from fastapi.responses import JSONResponse, Response, StreamingResponse
 # yapf: disable
 from tensorrt_llm.executor import CppExecutorError
 from tensorrt_llm.llmapi.disagg_utils import (ConditionalDisaggConfig,
+                                              MetadataServerConfig,
                                               RouterConfig)
+from tensorrt_llm.serve.metadata_server import create_metadata_server
 from tensorrt_llm.serve.openai_protocol import (ChatCompletionRequest,
                                                 ChatCompletionResponse,
                                                 CompletionRequest,
@@ -42,14 +44,17 @@ class OpenAIDisaggServer:
                  server_start_timeout_secs: int = 180,
                  ctx_router_config: Optional[RouterConfig] = None,
                  gen_router_config: Optional[RouterConfig] = None,
-                 conditional_disagg_config: Optional[ConditionalDisaggConfig] = None):
+                 conditional_disagg_config: Optional[ConditionalDisaggConfig] = None,
+                 metadata_server_cfg: MetadataServerConfig = None):

         self.ctx_servers = ctx_servers
         self.gen_servers = gen_servers
-        self.ctx_router = create_router(ctx_router_config, ctx_servers)
-        self.gen_router = create_router(gen_router_config, gen_servers)
+        self.metadata_server = create_metadata_server(metadata_server_cfg)
+        self.ctx_router = create_router(ctx_router_config, ctx_servers, self.metadata_server)
+        self.gen_router = create_router(gen_router_config, gen_servers, self.metadata_server)
         self.conditional_disagg_config = conditional_disagg_config


         if (len(self.gen_servers) == 0):
             raise ValueError("At least one generation server must be provided")

@@ -71,7 +76,19 @@ class OpenAIDisaggServer:

             logging.info("Waiting for context and generation servers to be ready")
             await self.wait_for_servers_ready(server_start_timeout_secs)

+            if self.metadata_server:
+                logging.info("Starting server monitoring via metadata service")
+                await self.ctx_router.start_server_monitoring()
+                await self.gen_router.start_server_monitoring()
+
             yield

+            if self.metadata_server:
+                logging.info("Stopping server monitoring via metadata service")
+                await self.ctx_router.stop_server_monitoring()
+                await self.gen_router.stop_server_monitoring()
+
             await self.session.close()  # Ensure session cleanup

         self.app = FastAPI(lifespan=lifespan)

@@ -3,6 +3,7 @@ import asyncio
 import signal
 import traceback
 from contextlib import asynccontextmanager
+from datetime import datetime
 from http import HTTPStatus
 from pathlib import Path
 from typing import AsyncGenerator, AsyncIterator, List, Optional, Tuple
@@ -18,12 +19,14 @@ from tensorrt_llm.executor import CppExecutorError
 from tensorrt_llm.executor.postproc_worker import PostprocParams
 from tensorrt_llm.inputs import prompt_inputs
 from tensorrt_llm.llmapi import LLM
+from tensorrt_llm.llmapi.disagg_utils import MetadataServerConfig, ServerRole
 from tensorrt_llm.llmapi.llm import RequestOutput
 from tensorrt_llm.logger import logger
 from tensorrt_llm.serve.chat_utils import (ConversationMessage,
                                            apply_chat_template,
                                            check_multiple_response,
                                            parse_chat_messages_coroutines)
+from tensorrt_llm.serve.metadata_server import create_metadata_server
 from tensorrt_llm.serve.openai_protocol import (ChatCompletionRequest,
                                                 ChatCompletionResponse,
                                                 CompletionRequest,
@@ -48,9 +51,14 @@ class OpenAIServer:

     def __init__(self,
                  llm: LLM,
-                 model: str):
+                 model: str,
+                 server_role: Optional[ServerRole],
+                 metadata_server_cfg: MetadataServerConfig):
         self.llm = llm
         self.tokenizer = llm.tokenizer
+        self.metadata_server = create_metadata_server(metadata_server_cfg)
+        self.server_role = server_role
+        self.binding_addr = None  # Will be set in __call__
         try:
             hf_tokenizer_path = llm._hf_model_dir or self.tokenizer.tokenizer.name_or_path
             self.processor = AutoProcessor.from_pretrained(hf_tokenizer_path)
@@ -68,8 +76,25 @@ class OpenAIServer:

         @asynccontextmanager
         async def lifespan(app: FastAPI):
+            if self.metadata_server is not None:
+                metadata = {
+                    "model": self.model,
+                    "version": VERSION,
+                    "timestamp": datetime.now().isoformat(),
+                    "server_role": server_role.name,
+                    "url": self.binding_addr
+                }
+                # TODO: add more metadata
+                # Register with ETCD using the existing key format
+                self.metadata_server.put(f"trtllm/{self.llm.llm_id}", metadata)
+                logger.info(f"trtllm/{self.llm.llm_id} is registered")
+
             # terminate rank0 worker
             yield

+            if self.metadata_server is not None:
+                self.metadata_server.remove(f"trtllm/{self.llm.llm_id}")
+                logger.info(f"trtllm/{self.llm.llm_id} is unregistered")
             self.llm.shutdown()

         self.app = FastAPI(lifespan=lifespan)
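For illustration, the value registered under trtllm/<llm_id> by the lifespan hook above would look roughly like this; every value is an example, and "version" is whatever VERSION string the installed package reports:

# Illustrative shape of the registered metadata (not captured from a real run):
{
    "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "version": "<package VERSION string>",
    "timestamp": "2025-05-05T12:00:00.000000",
    "server_role": "CONTEXT",
    "url": "http://localhost:8001"
}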
@@ -403,11 +428,12 @@ class OpenAIServer:
                 # If internal executor error is raised, shutdown the server
                 signal.raise_signal(signal.SIGINT)
             except Exception as e:
-                print(f"Encountered an exception: {str(e)}")
                 traceback.print_exc()
                 return self.create_error_response(str(e))

     async def __call__(self, host, port):
+        # Store the binding address for server registration
+        self.binding_addr = f"http://{host}:{port}"
         config = uvicorn.Config(self.app,
                                 host=host,
                                 port=port,
@@ -1,14 +1,16 @@
+import asyncio
 import heapq
 from abc import ABC, abstractmethod
-from typing import Iterable, Optional, Union
+from typing import Dict, Iterable, List, Optional, Union

 import aiohttp
 from transformers import AutoTokenizer

 from tensorrt_llm.bindings.internal.batch_manager import (BlockKey,
                                                           BlockKeyHasher)
-from tensorrt_llm.llmapi.disagg_utils import RouterConfig
+from tensorrt_llm.llmapi.disagg_utils import RouterConfig, ServerRole
 from tensorrt_llm.logger import logger
+from tensorrt_llm.serve.metadata_server import JsonDictionary
 from tensorrt_llm.serve.openai_protocol import (ChatCompletionRequest,
                                                 CompletionRequest)

@@ -62,6 +64,13 @@ class ServerState:
             self._num_active_requests -= 1
             self._num_active_tokens -= num_tokens

+    async def is_healthy(self) -> bool:
+        try:
+            async with self._session.get(self._server + "/health") as response:
+                return response.status == 200
+        except Exception:
+            return False
+

 class KvCacheAwareServerState(ServerState):

@@ -135,8 +144,24 @@ class KvCacheAwareServerState(ServerState):

 class Router(ABC):

-    def __init__(self, servers: list[str] = None):
-        self._servers = servers
+    def __init__(self,
+                 server_role: ServerRole,
+                 servers: List[str] = None,
+                 metadata_server: JsonDictionary = None):
+        self._servers = servers or []
+        self._metadata_server = metadata_server
+        self._server_role = server_role
+        self._lock = asyncio.Lock()
+        self._monitor_task = None
+        self._session = None
+        self._health_check_timeout = 5.0  # Default timeout in seconds
+
+    def _on_servers_updated(self, old_servers, new_servers):
+        """Called when the server list changes. Override in subclasses to handle index resets.
+        Args:
+            old_servers: The previous server list
+            new_servers: The new server list
+        """

     @abstractmethod
     async def get_next_server(self, request: OpenAIRequest) -> tuple[str, dict]:
@@ -146,16 +171,276 @@ class Router(ABC):
     async def finish_request(self, request: OpenAIRequest):
         pass

+    async def start_server_monitoring(self, poll_interval: int = 10):
+        """Start monitoring servers update from metadata service"""
+        if not self._metadata_server:
+            logger.info(
+                "No metadata server configured, skipping server monitoring")
+            return
+
+        # Create a session for health checks if it doesn't exist
+        if not self._session:
+            self._session = aiohttp.ClientSession()
+
+        logger.info(
+            f"Starting server monitoring for {self._server_role} servers")
+        self._monitor_task = asyncio.create_task(
+            self._monitor_servers(poll_interval))
+
+    async def stop_server_monitoring(self):
+        """Stop monitoring servers update from metadata service"""
+        if self._monitor_task:
+            logger.info(
+                f"Stopping server monitoring for {self._server_role} servers")
+            self._monitor_task.cancel()
+            try:
+                await self._monitor_task
+            except asyncio.CancelledError:
+                pass
+            self._monitor_task = None
+
+        # Close session when stopping monitoring
+        await self.close_session()
+
+    async def close_session(self):
+        if self._session:
+            try:
+                await self._session.close()
+                self._session = None
+                logger.debug("HTTP session closed")
+            except Exception as e:
+                logger.error(f"Error closing session: {e}")
+                self._session = None
+
+    async def _monitor_servers(self, poll_interval: int = 10):
+        while True:
+            try:
+                if self._metadata_server:
+                    # Get servers from metadata
+                    server_key_map = await self.fetch_live_servers()
+
+                    # Check health and get live servers
+                    live_servers = await self.check_servers_health(
+                        server_key_map)
+
+                    # Filter by server role if needed
+                    role_specific_servers = self._filter_servers_by_role(
+                        live_servers, server_key_map)
+
+                    # Use filtered servers if available
+                    final_servers = role_specific_servers if role_specific_servers else []
+
+                    # Update server list
+                    async with self._lock:
+                        if final_servers != self._servers:
+                            num_old_servers = len(self._servers)
+                            old_servers = self._servers.copy()
+                            self._servers = final_servers
+                            num_new_servers = len(self._servers)
+
+                            # Call handler for server list changes
+                            self._on_servers_updated(old_servers, self._servers)
+
+                            logger.info(
+                                f"Updated {self._server_role} server list: {num_old_servers} -> {num_new_servers} servers"
+                            )
+                            if logger.level == "debug" and self._servers:
+                                for server in self._servers:
+                                    logger.debug(f"  - {server}")
+                        else:
+                            logger.debug(
+                                f"No change in {self._server_role} server list: {len(self._servers)} servers"
+                            )
+            except Exception as e:
+                logger.error(f"Error in server monitoring: {e}")
+
+            # Wait before next poll
+            await asyncio.sleep(poll_interval)
+
+    def _filter_servers_by_role(self, servers, server_key_map):
+        """Filter servers by role (context or generation)"""
+        if not self._metadata_server or not servers:
+            return []
+
+        filtered_servers = []
+        # Invert to get {url: key} for lookup
+        url_to_key = {url: key for key, url in server_key_map.items()}
+
+        for server_url in servers:
+            key = url_to_key.get(server_url)
+            if key:
+                server_metadata = self._metadata_server.get(key)
+                if server_metadata:
+                    # Use either server_type or server_role field
+                    server_type = server_metadata.get('server_type', '').lower()
+                    if not server_type:
+                        server_type = server_metadata.get('server_role',
+                                                          '').lower()
+
+                    # Extract port for visibility
+                    parts = server_url.split(':')
+                    if len(parts) >= 3:
+                        parts[2]
+
+                    # Check if server type matches our role
+                    if (self._server_role == ServerRole.CONTEXT and server_type == 'context') or \
+                       (self._server_role == ServerRole.GENERATION and server_type == 'generation'):
+                        filtered_servers.append(server_url)
+
+        return filtered_servers
+
+    async def fetch_live_servers(self) -> Dict[str, str]:
+        """Fetch all servers from metadata service and return {key: url} mapping"""
+        if not self._metadata_server:
+            # Only use static servers if no metadata server
+            return {server: "" for server in self._servers}
+
+        # If metadata server is available, ignore static server list entirely
+        server_key_map = {}
+        try:
+            # Get all keys from the metadata server
+            all_keys = self._metadata_server.keys()
+            logger.debug(f"Found {len(all_keys)} keys in metadata server")
+
+            # Filter keys that start with 'trtllm/' and extract server metadata
+            matching_keys = 0
+            for key in all_keys:
+                if key.startswith('trtllm/'):
+                    matching_keys += 1
+                    server_metadata = self._metadata_server.get(key)
+                    if server_metadata and isinstance(
+                            server_metadata, dict) and 'url' in server_metadata:
+                        server_key_map[key] = server_metadata['url']
+
+                        # Check if metadata includes health check timeout
+                        if 'health_check_timeout' in server_metadata:
+                            try:
+                                self._health_check_timeout = float(
+                                    server_metadata['health_check_timeout'])
+                                logger.debug(
+                                    f"Using health check timeout: {self._health_check_timeout}s"
+                                )
+                            except (ValueError, TypeError):
+                                logger.warning(
+                                    f"Invalid health_check_timeout value: {server_metadata['health_check_timeout']}"
+                                )
+
+            if server_key_map:
+                logger.info(
+                    f"Using {len(server_key_map)} servers from metadata service"
+                )
+            else:
+                logger.warning("No servers found in metadata service")
+
+        except Exception as e:
+            logger.error(f"Error fetching servers from metadata service: {e}")
+
+        return server_key_map
+
+    async def check_servers_health(self,
+                                   server_key_map: Dict[str, str]) -> List[str]:
+        """Check health of servers and remove dead ones from metadata service"""
+        live_servers = []
+        dead_servers = []
+
+        try:
+            # Check health of each server
+            for key, server_url in server_key_map.items():
+                # First attempt - no printing errors
+                is_healthy = await self._check_server_health(server_url,
+                                                             silent=True)
+
+                # If first attempt failed, try again before declaring server dead
+                if not is_healthy:
+                    # Second attempt - will print errors if it fails
+                    is_healthy = await self._check_server_health(server_url,
+                                                                 silent=False)
+
+                    if not is_healthy:
+                        # Only now add to dead servers
+                        dead_servers.append((key, server_url))
+                        logger.warning(
+                            f"Server {server_url} is not healthy after retry - removing"
+                        )
+                    else:
+                        live_servers.append(server_url)
+                else:
+                    live_servers.append(server_url)

+            # Remove dead servers from etcd
+            for key, dead_server in dead_servers:
+                try:
+                    logger.info(
+                        f"Removing dead server {dead_server} from metadata server"
+                    )
+                    self._metadata_server.remove(key)
+                except Exception as e:
+                    logger.error(
+                        f"Error removing dead server from metadata service: {e}"
+                    )
+
+        except Exception as e:
+            logger.error(f"Error checking server health: {e}")
+
+        return live_servers if live_servers else self._servers
+
+    async def _check_server_health(self, server_url, silent=False) -> bool:
+        """Check if a server is healthy by querying its health endpoint"""
+        if not self._session:
+            self._session = aiohttp.ClientSession()
+
+        try:
+            async with self._session.get(
+                    f"{server_url}/health",
+                    timeout=self._health_check_timeout) as response:
+                if response.status != 200:
+                    if not silent:
+                        logger.warning(
+                            f"Server {server_url} is not healthy (status: {response.status})"
+                        )
+                    return False
+                return True
+        except Exception as e:
+            if not silent:
+                logger.warning(f"Server {server_url} is not reachable: {e}")
+            return False
+

 class RoundRobinRouter(Router):

-    def __init__(self, servers: list[str] = None, **kwargs):
-        super().__init__(servers)
+    def __init__(self,
+                 server_role: ServerRole,
+                 servers: List[str] = None,
+                 metadata_server: JsonDictionary = None,
+                 **kwargs):
+        super().__init__(server_role, servers, metadata_server)
         self._server_idx = 0

+    def _on_servers_updated(self, old_servers, new_servers):
+        """Reset the index when servers are removed to prevent index out of bounds errors."""
+        if len(new_servers) < len(old_servers):
+            # Servers were removed, reset the index
+            self._server_idx = 0
+        elif self._server_idx >= len(new_servers):
+            # Safety check: ensure index is always within bounds
+            self._server_idx = 0
+
     async def get_next_server(self, request: OpenAIRequest) -> tuple[str, dict]:
-        server = self._servers[self._server_idx]
-        self._server_idx = (self._server_idx + 1) % len(self._servers)
+        if not self._servers:
+            if self._metadata_server:
+                raise ValueError(
+                    f"No {self._server_role} servers available in metadata service"
+                )
+            else:
+                raise ValueError(f"No {self._server_role} servers available")
+
+        async with self._lock:
+            # Safety check: ensure index is within bounds
+            if self._server_idx >= len(self._servers):
+                self._server_idx = 0
+
+            server = self._servers[self._server_idx]
+            self._server_idx = (self._server_idx + 1) % len(self._servers)
         return server, {}

     async def finish_request(self, request: OpenAIRequest):
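As a rough lifecycle sketch of the monitoring API added above; no metadata server is configured here, so start_server_monitoring logs and returns, and the router simply round-robins over its static list (the request argument is unused by this router, and the URL is illustrative):

import asyncio

from tensorrt_llm.llmapi.disagg_utils import ServerRole
from tensorrt_llm.serve.router import RoundRobinRouter


async def demo():
    router = RoundRobinRouter(ServerRole.GENERATION,
                              servers=["http://localhost:8002"])
    await router.start_server_monitoring()  # no-op without a metadata server
    server, _ = await router.get_next_server(None)  # request unused here
    print(server)  # -> http://localhost:8002
    await router.stop_server_monitoring()

asyncio.run(demo())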
@@ -165,11 +450,12 @@ class RoundRobinRouter(Router):
 class LoadBalancingRouter(Router):

     def __init__(self,
-                 servers: list[str] = None,
+                 server_role: ServerRole,
+                 servers: List[str] = None,
+                 metadata_server: JsonDictionary = None,
                  use_tokens: bool = False,
                  **kwargs):
-        super().__init__(servers)
-        self._lock = asyncio.Lock()
+        super().__init__(server_role, servers, metadata_server)
         # Load map between servers and their number of tokens processed
         self._server_state = {}
         self._server_load_heap = []
@@ -180,6 +466,25 @@ class LoadBalancingRouter(Router):
         self._use_tokens = use_tokens
         self._init_heap()

+    def _on_servers_updated(self, old_servers, new_servers):
+        """Rebuild the heap when the server list changes."""
+        # Keep the state for servers that still exist
+        current_state = {}
+        for server in new_servers:
+            if server in self._server_state:
+                # Keep existing state
+                current_state[server] = self._server_state[server]
+            else:
+                # Initialize new server state
+                current_state[server] = ServerState(server, self._use_tokens)
+
+        # Update state and rebuild heap
+        self._server_state = current_state
+        self._server_load_heap = []
+        for server in new_servers:
+            heapq.heappush(self._server_load_heap,
+                           (self._get_server_load(server), server))
+
     def _init_heap(self):
         for server in self._servers:
             self._server_state[server] = ServerState(server, self._use_tokens)
@@ -187,6 +492,14 @@ class LoadBalancingRouter(Router):
                            (self._get_server_load(server), server))

     async def get_next_server(self, request: OpenAIRequest) -> tuple[str, dict]:
+        if not self._servers:
+            if self._metadata_server:
+                raise ValueError(
+                    f"No {self._server_role} servers available in metadata service"
+                )
+            else:
+                raise ValueError(f"No {self._server_role} servers available")
+
         async with self._lock:
             server = heapq.heappop(self._server_load_heap)[1]
             await self._server_state[server].increment_load(request)
@@ -220,18 +533,20 @@ def block_key_hasher(token_ids: list[int],
 class KvCacheAwareRouter(Router):

     def __init__(self,
+                 server_role: ServerRole = None,
                  servers: list[str] = None,
+                 metadata_server: JsonDictionary = None,
                  use_tokens: bool = False,
                  max_batch_size: int = 64,
                  tokens_per_block: int = 32,
                  **kwargs):
-        super().__init__(servers)
+        super().__init__(server_role, servers, metadata_server)
         self._lock = asyncio.Lock()

         # Load map between servers and their number of tokens processed
         self._server_state: dict[str, KvCacheAwareServerState] = {
             server: KvCacheAwareServerState(server, use_tokens)
-            for server in servers
+            for server in servers or []
         }

         # Routing table to map requests to servers
@@ -313,7 +628,8 @@ class KvCacheAwareRouter(Router):


 def create_router(router_config: Optional[RouterConfig],
-                  servers: list[str]) -> Router:
+                  servers: List[str],
+                  metadata_server: JsonDictionary = None) -> Router:
     """
     Factory function to create different types of router instances.

@@ -331,7 +647,8 @@ def create_router(router_config: Optional[RouterConfig],
         ValueError: If an unsupported router type is provided
     """
     if router_config is None:
-        return RoundRobinRouter(servers)
+        # Create a default router without server_role
+        return RoundRobinRouter(None, servers)

     router_map = {
         "round_robin": RoundRobinRouter,
@@ -345,4 +662,6 @@ def create_router(router_config: Optional[RouterConfig],
         raise ValueError(f"Unsupported router type: {router_type}. "
                          f"Supported types are: {list(router_map.keys())}")

-    return router_class(servers, **router_config.args)
+    # Pass server_role as the first argument
+    return router_class(router_config.server_role, servers, metadata_server,
+                        **router_config.args)
tests/integration/defs/disaggregated/test_disaggregated_etcd.py (new file, 426 lines)
@@ -0,0 +1,426 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import signal
import subprocess
import time

import pytest
import requests

from tensorrt_llm.logger import logger

# Configuration file paths
EXAMPLES_DIR = "examples/disaggregated"
CLIENTS_DIR = f"{EXAMPLES_DIR}/clients"
CONTEXT_CONFIG_FILE = f"{EXAMPLES_DIR}/context_extra-llm-api-config.yml"
GENERATION_CONFIG_FILE = f"{EXAMPLES_DIR}/gen_extra-llm-api-config.yml"
ETCD_CONFIG_FILE = f"{EXAMPLES_DIR}/etcd_config.yaml"
DISAGG_CONFIG_FILE = f"{EXAMPLES_DIR}/disagg_config.yaml"
CLIENT_SCRIPT_FILE = f"{CLIENTS_DIR}/disagg_client.py"
PROMPTS_FILE = f"{CLIENTS_DIR}/prompts.json"


def kill_automated_disaggregated_processes():
    """Kill any existing automated disaggregated processes."""
    try:
        subprocess.run(['pkill', '-9', '-f', 'trtllm-serve'], check=False)
    except Exception:
        pass


def cleanup_automated_output_files():
    """Clean up output files from previous runs."""
    for file in [
            'output.json', 'output_streaming.json', 'output_workers.log',
            'output_disagg.log'
    ]:
        try:
            os.remove(file)
        except FileNotFoundError:
            pass


def start_context_server(config,
                         gpu_id: int,
                         port: int,
                         env=None) -> subprocess.Popen:
    """Start a context server on specified GPU and port."""
    cmd = [
        "trtllm-serve", config['model_path'], "--host", "localhost", "--port",
        str(port), "--backend", "pytorch", "--extra_llm_api_options",
        f"./{CONTEXT_CONFIG_FILE}", "--metadata_server_config_file",
        ETCD_CONFIG_FILE, "--server_role", "CONTEXT"
    ]

    server_env = env.copy() if env else os.environ.copy()
    server_env["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    server_env["TRTLLM_USE_UCX_KVCACHE"] = "1"

    logger.info(f"Starting CONTEXT server on GPU {gpu_id} (port {port})...")
    process = subprocess.Popen(cmd,
                               env=server_env,
                               stdout=subprocess.PIPE,
                               stderr=subprocess.PIPE,
                               text=True,
                               bufsize=1)
    return process


def start_generation_server(config,
                            gpu_id: int,
                            port: int,
                            env=None) -> subprocess.Popen:
    """Start a generation server on specified GPU and port."""
    cmd = [
        "trtllm-serve", config['model_path'], "--host", "localhost", "--port",
        str(port), "--backend", "pytorch", "--extra_llm_api_options",
        f"./{GENERATION_CONFIG_FILE}", "--metadata_server_config_file",
        ETCD_CONFIG_FILE, "--server_role", "GENERATION"
    ]

    server_env = env.copy() if env else os.environ.copy()
    server_env["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    server_env["TRTLLM_USE_UCX_KVCACHE"] = "1"

    logger.info(f"Starting GENERATION server on GPU {gpu_id} (port {port})...")
    process = subprocess.Popen(cmd,
                               env=server_env,
                               stdout=subprocess.PIPE,
                               stderr=subprocess.PIPE,
                               text=True,
                               bufsize=1)
    return process


def start_disaggregated_service(config, env=None) -> subprocess.Popen:
    """Launch the disaggregated service."""
    cmd = [
        "trtllm-serve", "disaggregated", "-c", DISAGG_CONFIG_FILE, "-m",
        ETCD_CONFIG_FILE
    ]

    logger.info("Launching disaggregated service...")
    process = subprocess.Popen(cmd,
                               env=env,
                               stdout=subprocess.PIPE,
                               stderr=subprocess.PIPE,
                               text=True,
                               bufsize=1)
    return process


def wait_for_server_health(port: int, timeout: int = 120) -> bool:
    """Wait for server to be healthy by checking /health endpoint."""
    url = f"http://localhost:{port}/health"
    start_time = time.time()
    logger.info(f"Waiting for server on port {port} to be healthy...")

    while time.time() - start_time < timeout:
        try:
            response = requests.get(url, timeout=5)
            if response.status_code == 200:
                logger.info(f"Server on port {port} is healthy")
                return True
        except requests.RequestException:
            pass
        time.sleep(2)

    logger.error(f"Timed out waiting for server on port {port}")
    return False


def run_client_test(config, env=None) -> bool:
    """Run the disaggregated client test."""
    cmd = [
        "python3", f"./{CLIENT_SCRIPT_FILE}", "-c", DISAGG_CONFIG_FILE, "-p",
        f"./{PROMPTS_FILE}"
    ]

    logger.info("Running disaggregated client test...")
    result = subprocess.run(cmd,
                            env=env,
                            stdout=subprocess.PIPE,
                            stderr=subprocess.PIPE,
                            text=True)

    if result.returncode == 0:
        logger.info("Client test succeeded")
        logger.info(f"Client output: {result.stdout}")
        return True
    else:
        logger.error(f"Client test failed with return code {result.returncode}")
        logger.error(f"Error output: {result.stderr}")
        logger.error(f"Standard output: {result.stdout}")
        return False


def kill_server_by_port(port: int) -> bool:
    """Find and kill a process by port using lsof."""
    try:
        # Find PID using port
        cmd = ["lsof", "-t", f"-i:{port}"]
        result = subprocess.run(cmd, stdout=subprocess.PIPE, text=True)

        if result.stdout.strip():
            pid = int(result.stdout.strip())
            os.kill(pid, signal.SIGKILL)
            logger.info(f"Killed process {pid} on port {port}")
            return True
        else:
            logger.warning(f"No process found on port {port}")
            return False
    except Exception as e:
        logger.error(f"Error killing process on port {port}: {e}")
        return False


def cleanup_processes(processes):
    """Kill all started processes."""
    logger.info("Cleaning up all processes...")

    for name, process in processes.items():
        if process.poll() is None:  # Still running
            logger.info(f"Terminating {name} (PID: {process.pid})")
            try:
                process.terminate()
                process.wait(timeout=5)
            except (subprocess.TimeoutExpired, ProcessLookupError):
                logger.warning(f"Force killing {name} (PID: {process.pid})")
                try:
                    process.kill()
                except ProcessLookupError:
                    pass


def start_etcd_server(working_dir, env=None) -> subprocess.Popen:
    """Start etcd server."""
    cmd = ["etcd"]

    logger.info("Starting etcd server...")
    process = subprocess.Popen(cmd,
                               env=env,
                               stdout=subprocess.PIPE,
                               stderr=subprocess.PIPE,
                               text=True,
                               bufsize=1,
                               cwd=working_dir)
    return process


def cleanup_etcd_data(env=None):
    """Clean up etcd data using etcdctl."""
    cmd = ["etcdctl", "del", "--prefix", "trtllm/"]

    logger.info("Cleaning etcd data...")
    result = subprocess.run(cmd,
                            env=env,
                            stdout=subprocess.PIPE,
                            stderr=subprocess.PIPE,
                            text=True)

    if result.returncode == 0:
        logger.info("Successfully cleaned etcd data")
    else:
        logger.warning(f"Failed to clean etcd data: {result.stderr}")


def create_config_files(config):
    """Create necessary configuration files"""
    # Create context config file
    context_config_content = """pytorch_backend_config:
  disable_overlap_scheduler: True
cache_transceiver_config:
  max_num_tokens: 2048"""

    with open(CONTEXT_CONFIG_FILE, 'w') as file:
        file.write(context_config_content)

    # Create generation config file
    generation_config_content = """cache_transceiver_config:
  max_num_tokens: 2048"""

    with open(GENERATION_CONFIG_FILE, 'w') as file:
        file.write(generation_config_content)

    # Create etcd config file
    etcd_config_content = """server_type: "etcd"
hostname: "localhost"
port: 2379
health_check_timeout: 5.0"""

    with open(ETCD_CONFIG_FILE, 'w') as file:
        file.write(etcd_config_content)

    disagg_config_content = """hostname: localhost
port: 8000
backend: pytorch
context_servers:
  num_instances: 1
  urls:
      - "localhost:8001"
generation_servers:
  num_instances: 1
  urls:
      - "localhost:8002"
"""

    with open(DISAGG_CONFIG_FILE, 'w') as file:
        file.write(disagg_config_content)

    return True


def run_automated_disaggregated_test(example_dir, env=None, cwd=None):
    """Run automated disaggregated test with given configuration."""
    kill_automated_disaggregated_processes()
    cleanup_automated_output_files()

    config = {"model_path": "TinyLlama/TinyLlama-1.1B-Chat-v1.0"}

    # Create configuration files
    create_config_files(config)
    processes = {}

    try:
        # Step 1: Start etcd server
        logger.info("Step 1: Starting etcd server...")
        processes["etcd"] = start_etcd_server(".", env=env)
        time.sleep(5)  # Give etcd time to start

        # Step 2: Clean etcd data
        logger.info("Step 2: Cleaning etcd data...")
        cleanup_etcd_data(env=env)

        # Step 3: Start context server on GPU 0 (port 8001)
        logger.info("Step 3: Starting context server on GPU 0 (port 8001)...")
        processes["context_8001"] = start_context_server(config,
                                                         gpu_id=0,
                                                         port=8001,
                                                         env=env)

        # Step 4: Start generation server on GPU 1 (port 8002)
        logger.info(
            "Step 4: Starting generation server on GPU 1 (port 8002)...")
        processes["generation_8002"] = start_generation_server(config,
                                                               gpu_id=1,
                                                               port=8002,
                                                               env=env)

        # Step 5: Wait till gen and context ready
        logger.info(
            "Step 5: Waiting for context and generation servers to be ready...")
        if not wait_for_server_health(port=8001):
            logger.error("Context server on port 8001 failed to start")
            return False
        if not wait_for_server_health(port=8002):
            logger.error("Generation server on port 8002 failed to start")
            return False

        # Step 6: Start disaggregated service
        logger.info("Step 6: Starting disaggregated service...")
        processes["disagg_service"] = start_disaggregated_service(config,
                                                                  env=env)

        # Step 7: Wait for disaggregated service and run first client test
        logger.info(
            "Step 7: Waiting for disaggregated service and running first client test..."
        )
        if not wait_for_server_health(port=8000):
            logger.error("Disaggregated service failed to start")
            return False

        first_test_success = run_client_test(config, env=env)
        if not first_test_success:
            logger.error("First client test failed")
            return False

        # Step 8: Start second context server on GPU 2 (port 8003)
        logger.info(
            "Step 8: Starting second context server on GPU 2 (port 8003)...")
        processes["context_8003"] = start_context_server(config,
                                                         gpu_id=2,
                                                         port=8003,
                                                         env=env)

        # Step 9: Wait till ready and then 10 seconds, run second client test
        logger.info(
            "Step 9: Waiting for second context server and running second client test..."
        )
        if not wait_for_server_health(port=8003):
            logger.error("Second context server on port 8003 failed to start")
            return False

        logger.info("Waiting additional 10 seconds for system stabilization...")
        time.sleep(10)

        second_test_success = run_client_test(config, env=env)
        if not second_test_success:
            logger.error("Second client test failed")
            return False

        # Step 10: Kill 8001 process (first context server)
        logger.info("Step 10: Killing first context server (port 8001)...")
        if "context_8001" in processes:
            process = processes["context_8001"]
            if process.poll() is None:
                process.terminate()
                try:
                    process.wait(timeout=10)
                except subprocess.TimeoutExpired:
                    process.kill()
        kill_server_by_port(8001)

        # Step 11: Wait a few seconds and run final client test
        logger.info(
            "Step 11: Waiting a few seconds and running final client test...")
        time.sleep(5)

        final_test_success = run_client_test(config, env=env)
        if not final_test_success:
            logger.error("Final client test failed")
            return False

        logger.info("✅ All automated disaggregated tests passed successfully!")
        return True

    except Exception as e:
        logger.exception(f"Error during automated test: {e}")
        return False
    finally:
        cleanup_processes(processes)
        kill_automated_disaggregated_processes()


@pytest.mark.parametrize("llama_model_root", ['TinyLlama-1.1B-Chat-v1.0'],
                         indirect=True)
def test_automated_disaggregated_complete(disaggregated_test_root,
                                          disaggregated_example_root, llm_venv,
                                          llama_model_root):
    src_dst_dict = {
        llama_model_root:
        f"{llm_venv.get_working_directory()}/TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    }
    for src, dst in src_dst_dict.items():
        if not os.path.islink(dst):
            os.makedirs(os.path.dirname(dst), exist_ok=True)
            os.symlink(src, dst, target_is_directory=True)

    success = run_automated_disaggregated_test(
        disaggregated_example_root,
        env=llm_venv._new_env,
        cwd=llm_venv.get_working_directory())
    assert success, "Automated disaggregated test failed"
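Note that this integration test assumes the etcd and etcdctl binaries are available on PATH, and that at least three GPUs are visible (GPUs 0 and 2 for the two context servers, GPU 1 for the generation server).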
@@ -16,7 +16,7 @@ from tensorrt_llm import logger
 from tensorrt_llm.serve.openai_protocol import (CompletionRequest,
                                                 DisaggregatedParams)
 from tensorrt_llm.serve.router import (KvCacheAwareRouter,
-                                       KvCacheAwareServerState,
+                                       KvCacheAwareServerState, ServerRole,
                                        block_key_hasher)

 MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
@@ -350,8 +350,10 @@ class KvCacheAwareRouterTester(BasicWorkerTester):
                  server_start_timeout_secs: int = 180):
         super().__init__(ctx_servers, gen_servers, req_timeout_secs,
                          server_start_timeout_secs)
-        self.ctx_router = KvCacheAwareRouter(ctx_servers)
-        self.gen_router = KvCacheAwareRouter(gen_servers)
+        self.ctx_router = KvCacheAwareRouter(server_role=ServerRole.CONTEXT,
+                                             servers=ctx_servers)
+        self.gen_router = KvCacheAwareRouter(server_role=ServerRole.GENERATION,
+                                             servers=gen_servers)

     async def multi_round_request(self,
                                   session: aiohttp.ClientSession,
@@ -140,3 +140,6 @@ properties:
   workspace:
     annotation: pathlib.Path
     default: inspect._empty
+  llm_id:
+    annotation: str
+    default: inspect._empty
tests/unittest/disaggregated/test_remoteDictionary.py (new file, 113 lines)
@@ -0,0 +1,113 @@
import unittest

import etcd3

from tensorrt_llm.serve.metadata_server import EtcdDictionary


class TestEtcdDictionary(unittest.TestCase):

    def setUp(self):
        # Setup etcd connection parameters
        self.host = "localhost"
        self.port = 2379

        # Create a clean etcd client for test setup/teardown
        self.cleanup_client = etcd3.client(host=self.host, port=self.port)

        # Create the dictionary under test
        self.etcd_dict = EtcdDictionary(host=self.host, port=self.port)

        # Clean up any existing test keys before each test
        self._cleanup_test_keys()

    def tearDown(self):
        # Clean up test keys after each test
        self._cleanup_test_keys()

    def _cleanup_test_keys(self):
        # Helper method to remove test keys
        test_keys = [
            "trtllm/1/test_key1", "trtllm/1/test_key2", "trtllm/2/test_key3"
        ]
        for key in test_keys:
            self.cleanup_client.delete(key)

    def test_put_and_get(self):
        # Test putting and getting a value
        test_key = "trtllm/1/test_key1"
        test_value = "value1"

        # Put the value
        self.etcd_dict.put(test_key, test_value)

        # Get the value
        value, _ = self.etcd_dict.get(test_key)

        # Assert
        self.assertEqual(value.decode('utf-8'), test_value)

    def test_remove(self):
        # Test removing a key
        test_key = "trtllm/1/test_key2"
        test_value = "value2"

        # Setup: Put a value first
        self.etcd_dict.put(test_key, test_value)

        # Remove the key
        self.etcd_dict.remove(test_key)

        # Verify key is removed by trying to get it
        result = self.cleanup_client.get(test_key)
        self.assertIsNone(
            result[0])  # etcd3 returns (None, None) when key doesn't exist

    def test_keys(self):
        # Test listing all keys
        test_data = {
            "trtllm/1/test_key1": "value1",
            "trtllm/1/test_key2": "value2",
            "trtllm/2/test_key3": "value3"
        }

        prefix_data = {"trtllm/1": "value1", "trtllm/2": "value2"}

        # Setup: Put multiple values
        for key, value in test_data.items():
            self.etcd_dict.put(key, value)

        # Get all keys
        keys = self.etcd_dict.keys()

        # Assert all test keys are present
        prefix_keys = set(prefix_data.keys())
        extract_keys = set(keys)
        self.assertEqual(prefix_keys, extract_keys)

    def test_get_nonexistent_key(self):
        # Test getting a key that doesn't exist
        result, _ = self.etcd_dict.get("nonexistent_key")
        self.assertIsNone(result)

    def test_put_update_existing(self):
        # Test updating an existing key
        test_key = "trtllm/1/test_key1"
        initial_value = "initial_value"
        updated_value = "updated_value"

        # Put initial value
        self.etcd_dict.put(test_key, initial_value)

        # Update value
        self.etcd_dict.put(test_key, updated_value)

        # Get updated value
        value, _ = self.etcd_dict.get(test_key)

        # Assert
        self.assertEqual(value.decode('utf-8'), updated_value)


if __name__ == '__main__':
    unittest.main()
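These unit tests likewise talk to a live etcd instance on localhost:2379; they will fail unless one is running locally, for example a bare etcd process as in the integration test above.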
@ -1,8 +1,13 @@
|
||||
import asyncio
|
||||
import copy
|
||||
import threading
|
||||
import time
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
from tensorrt_llm.llmapi.disagg_utils import RouterConfig
|
||||
from tensorrt_llm.logger import logger
|
||||
from tensorrt_llm.serve.openai_protocol import (ChatCompletionRequest,
|
||||
CompletionRequest,
|
||||
DisaggregatedParams)
|
||||
@ -10,6 +15,40 @@ from tensorrt_llm.serve.router import (KvCacheAwareRouter, LoadBalancingRouter,
|
||||
RoundRobinRouter, create_router)
|
||||
|
||||
|
||||
# Mock class for metadata server
|
||||
class MockMetadataServer:
|
||||
"""Mock metadata server for testing router interactions"""
|
||||
|
||||
def __init__(self):
|
||||
self.servers = {}
|
||||
self.lock = threading.Lock()
|
||||
|
||||
def get(self, key):
|
||||
with self.lock:
|
||||
return self.servers.get(key)
|
||||
|
||||
def put(self, key, value):
|
||||
with self.lock:
|
||||
self.servers[key] = value
|
||||
return True
|
||||
|
||||
def remove(self, key):
|
||||
with self.lock:
|
||||
if key in self.servers:
|
||||
del self.servers[key]
|
||||
return True
|
||||
return False
|
||||
|
||||
def add_server(self, key, url):
|
||||
with self.lock:
|
||||
self.servers[key] = url
|
||||
return True
|
||||
|
||||
def keys(self, prefix=""):
|
||||
with self.lock:
|
||||
return [k for k in self.servers.keys() if k.startswith(prefix)]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def servers():
|
||||
return ["server1", "server2", "server3"]
|
||||
@ -301,3 +340,245 @@ def test_create_router(servers):
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
create_router(RouterConfig(type="unsupported_router"), servers)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_metadata_server():
|
||||
return MockMetadataServer()
|
||||
|
||||
|
||||
@pytest.mark.slow
def test_fetch_live_servers_context(mock_metadata_server):
    """Test fetching live context servers"""
    # Create router with mock metadata server
    router = RoundRobinRouter(server_role="context",
                              metadata_servers=[mock_metadata_server])

    # Start server monitoring with a shorter poll interval for testing,
    # but still long enough to verify the actual behavior
    poll_interval = 10  # seconds
    asyncio.run(router.start_server_monitoring(poll_interval=poll_interval))

    try:
        # Initial check - should be no servers
        servers = router.fetch_live_servers()
        assert len(servers) == 0, "Should have no servers initially"

        # Add a server
        server_key = "servers/context/server1"
        server_url = "http://localhost:8001"
        mock_metadata_server.add_server(server_key, {"url": server_url})

        # Wait for the polling interval to pass (add 50% buffer)
        wait_time = poll_interval * 1.5
        logger.info(f"Waiting {wait_time} seconds for server to be detected...")
        time.sleep(wait_time)

        # Fetch servers again
        servers = router.fetch_live_servers()
        assert len(servers) == 1, "Should have one server after adding and waiting"
        assert servers[0] == server_url, "Server URL should match what was added"

        # Add another server
        server_key2 = "servers/context/server2"
        server_url2 = "http://localhost:8002"
        mock_metadata_server.add_server(server_key2, {"url": server_url2})

        # Wait for the polling interval again
        logger.info(
            f"Waiting {wait_time} seconds for second server to be detected...")
        time.sleep(wait_time)

        # Fetch servers again
        servers = router.fetch_live_servers()
        assert len(servers) == 2, "Should have two servers after adding second one and waiting"
        assert server_url in servers, "First server should still be present"
        assert server_url2 in servers, "Second server should be present"

        # Remove a server
        mock_metadata_server.remove(server_key)

        # Wait for the polling interval again
        logger.info(
            f"Waiting {wait_time} seconds for server removal to be detected...")
        time.sleep(wait_time)

        # Fetch servers again
        servers = router.fetch_live_servers()
        assert len(servers) == 1, "Should have one server after removing one and waiting"
        assert servers[0] == server_url2, "Remaining server should be the second one"
    finally:
        # Clean up
        asyncio.run(router.stop_server_monitoring())

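
The fixed sleeps above only make sense against a periodic poller: changes in the metadata server become visible to fetch_live_servers() at most one poll interval later. A hypothetical sketch of such a loop, inferred from the test timing alone (the real start_server_monitoring implementation may differ), is:

# Hypothetical polling loop (inferred from test timing, not the real code)
import asyncio


async def _poll_metadata(metadata_server, server_role, poll_interval, live_servers):
    prefix = f"servers/{server_role}/"
    while True:
        keys = metadata_server.keys(prefix)
        # Refresh the live-server list wholesale on each tick, so additions
        # and removals become visible only after a full poll interval
        live_servers[:] = [metadata_server.get(key)["url"] for key in keys]
        await asyncio.sleep(poll_interval)
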
@pytest.mark.slow
def test_fetch_live_servers_with_delay(mock_metadata_server):
    """Test fetching live servers with the actual polling delay"""
    # Create router with mock metadata server
    poll_interval = 5  # seconds

    router = RoundRobinRouter(server_role="context",
                              metadata_servers=[mock_metadata_server])

    # Start server monitoring with a shorter interval for testing
    asyncio.run(router.start_server_monitoring(poll_interval=poll_interval))

    try:
        # Initial check - should be no servers
        servers = router.fetch_live_servers()
        assert len(servers) == 0, "Should have no servers initially"

        # Add a server
        server_key = "servers/context/server1"
        server_url = "http://localhost:8001"
        mock_metadata_server.add_server(server_key, {"url": server_url})

        # Wait for a bit less than the polling interval - should still have no servers
        short_wait = poll_interval * 0.4
        logger.info(f"Waiting {short_wait} seconds (less than polling interval)...")
        time.sleep(short_wait)

        # Verify the server isn't discovered yet
        servers = router.fetch_live_servers()
        assert len(servers) == 0, "Should still have no servers before polling interval completes"

        # Wait for the polling interval to pass
        remaining_wait = poll_interval * 1.2
        logger.info(
            f"Waiting additional {remaining_wait} seconds for server to be detected...")
        time.sleep(remaining_wait)

        # Now should have the server
        servers = router.fetch_live_servers()
        assert len(servers) == 1, "Should have one server after polling interval"
        assert servers[0] == server_url, "Server URL should match what was added"

        # Remove the server
        mock_metadata_server.remove(server_key)

        # Wait for the polling interval to pass
        wait_time = poll_interval * 1.5
        logger.info(f"Waiting {wait_time} seconds for server removal to be detected...")
        time.sleep(wait_time)

        # Should now be empty again
        servers = router.fetch_live_servers()
        assert len(servers) == 0, "Should have no servers after removal and waiting"
    finally:
        # Clean up
        asyncio.run(router.stop_server_monitoring())

@pytest.mark.slow
def test_server_health_check(mock_metadata_server):
    """Test that unhealthy servers are filtered out"""
    # Create router with mock metadata server
    poll_interval = 5  # seconds

    router = RoundRobinRouter(server_role="context",
                              metadata_servers=[mock_metadata_server])

    # Start server monitoring
    asyncio.run(router.start_server_monitoring(poll_interval=poll_interval))

    try:
        # Add two servers
        server_key1 = "servers/context/server1"
        server_url1 = "http://localhost:8001"
        mock_metadata_server.add_server(server_key1, {"url": server_url1})

        server_key2 = "servers/context/server2"
        server_url2 = "http://localhost:8002"
        mock_metadata_server.add_server(server_key2, {"url": server_url2})

        # Wait for the polling interval to pass
        wait_time = poll_interval * 1.5
        logger.info(f"Waiting {wait_time} seconds for servers to be detected...")
        time.sleep(wait_time)

        # Mock the is_server_healthy method to simulate one server being down
        with mock.patch.object(router, 'is_server_healthy') as mock_is_healthy:
            # Only the second server is "healthy"
            mock_is_healthy.side_effect = lambda url: url == server_url2

            # Fetch servers with health check
            servers = router.fetch_live_servers(check_health=True)
            assert len(servers) == 1, "Should have one healthy server"
            assert servers[0] == server_url2, "Only healthy server should be returned"
    finally:
        # Clean up
        asyncio.run(router.stop_server_monitoring())

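
For context on the method being patched here: a health probe of this kind typically issues an HTTP GET against a health endpoint and treats errors as unhealthy. A minimal sketch, assuming a /health route and the requests package (both assumptions, not taken from this diff):

# Illustrative health probe (assumed endpoint, not the router's actual code)
import requests


def is_server_healthy(url: str, timeout: float = 2.0) -> bool:
    try:
        # Treat any non-200 response or network error as unhealthy
        return requests.get(f"{url}/health", timeout=timeout).status_code == 200
    except requests.RequestException:
        return False
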
@pytest.mark.slow
def test_load_balancing_router_fetch_servers(mock_metadata_server):
    """Test that LoadBalancingRouter fetches servers correctly"""
    # Create router with mock metadata server
    poll_interval = 10  # seconds

    router = LoadBalancingRouter(server_role="context",
                                 metadata_servers=[mock_metadata_server])

    # Start server monitoring
    asyncio.run(router.start_server_monitoring(poll_interval=poll_interval))

    try:
        # Add two servers
        server_key1 = "servers/context/server1"
        server_url1 = "http://localhost:8001"
        mock_metadata_server.add_server(server_key1, {"url": server_url1})

        server_key2 = "servers/context/server2"
        server_url2 = "http://localhost:8002"
        mock_metadata_server.add_server(server_key2, {"url": server_url2})

        # Wait for the polling interval to pass
        wait_time = poll_interval * 1.5
        logger.info(f"Waiting {wait_time} seconds for servers to be detected...")
        time.sleep(wait_time)

        # Fetch servers
        servers = router.fetch_live_servers()
        assert len(servers) == 2, "Should have two servers after waiting"

        # Remove all servers
        mock_metadata_server.remove(server_key1)
        mock_metadata_server.remove(server_key2)

        # Wait for the polling interval to pass
        logger.info(f"Waiting {wait_time} seconds for server removals to be detected...")
        time.sleep(wait_time)

        # Test handling of no servers
        servers = router.fetch_live_servers()
        assert len(servers) == 0, "Should have no servers after removing all and waiting"

        # get_next_server with no servers should raise ValueError
        with pytest.raises(ValueError):
            router.get_next_server()
    finally:
        # Clean up
        asyncio.run(router.stop_server_monitoring())
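
Note that every test in this block carries @pytest.mark.slow because each one sleeps through real polling intervals; during quick local iteration they can be deselected with pytest -m "not slow".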