# tensorrt_llm/serve/disagg_auto_scaling.py

import asyncio
import json
import os
import random
import socket
import time
from dataclasses import asdict, dataclass
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
from tensorrt_llm.llmapi.disagg_utils import DisaggClusterConfig, ServerRole
from tensorrt_llm.logger import logger
from .cluster_storage import (ClusterStorage, StorageItem, WatchEvent,
                              WatchEventType, key_time)


@dataclass
class WorkerInfo:
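    """Registration record for a single worker in the disaggregated cluster."""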
worker_id: str
host: str = ""
port: int = 0
role: ServerRole = ServerRole.CONTEXT
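

# Worker registrations are stored under
# "/trtllm-disagg/<cluster_name>/workers/<worker_id>" in the cluster storage.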
def get_worker_key_prefix(cluster_name: str) -> str:
return f"/trtllm-disagg/{cluster_name}/workers"
def get_worker_key(name: str, role: ServerRole, worker_id: str = "") -> str:
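    # NOTE: `role` is currently unused here; the worker_id generated by
    # DisaggClusterWorker already embeds the role name.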
return f"{get_worker_key_prefix(name)}/{worker_id}"
def get_host_from_uri(uri: str) -> str:
return uri.split("://")[1].split(":")[0]
# Get the local IP address by opening a UDP socket toward a remote host.
# If no remote host is provided, Google's public DNS server "8.8.8.8" is used.
# Connecting a UDP socket sends no traffic; it only selects the outbound interface.
def get_local_ip(remote_host: str = "8.8.8.8") -> str:
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
s.connect((remote_host, 80))
return s.getsockname()[0]


class DisaggClusterManager:
"""
The cluster manager is responsible for managing the workers in the cluster.
It will watch the workers and notify the router when the workers are changed.
"""
def __init__(self, config: DisaggClusterConfig, storage: ClusterStorage):
self._config = config
self._cluster_storage = storage
self._lock = asyncio.Lock()
self._minimal_ctx_worker_num = config.minimal_instances.context_servers
self._minimal_gen_worker_num = config.minimal_instances.generation_servers
self._current_ctx_workers = {} # worker_id -> WorkerInfo
self._current_gen_workers = {} # worker_id -> WorkerInfo
self._watch_handle = None
self._watch_task = None
def __del__(self):
try:
if asyncio.get_event_loop():
asyncio.run_coroutine_threadsafe(self.stop(),
asyncio.get_event_loop())
except RuntimeError:
# the event loop may not be running when the cluster manager is destroyed
pass
async def start(self) -> None:
await self._cluster_storage.start()
async def stop(self) -> None:
await self.unwatch_workers()
await self._cluster_storage.stop()
async def cluster_info(self) -> Dict[str, Any]:
async with self._lock:
return {
"current_workers": {
"context_servers": [
asdict(worker)
for worker in self._current_ctx_workers.values()
],
"generation_servers": [
asdict(worker)
for worker in self._current_gen_workers.values()
]
},
"minimal_instances": {
"context_servers": self._minimal_ctx_worker_num,
"generation_servers": self._minimal_gen_worker_num
},
}
@property
def current_ctx_worker_num(self) -> int:
return len(self._current_ctx_workers)
@property
def current_gen_worker_num(self) -> int:
return len(self._current_gen_workers)
@property
def worker_key_prefix(self) -> str:
return get_worker_key_prefix(self._config.cluster_name)
async def watch_workers(
self,
get_existing_first: bool = True,
on_event: Optional[Callable[[WorkerInfo, WatchEventType],
Awaitable[Any]]] = None):
if self._watch_handle:
logger.error("Watch handle is already initialized")
return []
workers = []
self._watch_handle = await self._cluster_storage.watch(
self.worker_key_prefix)
if get_existing_first:
            # There is a tiny gap between getting the existing workers and watching
            # the key, which could cause us to miss workers registered in between.
resp = await self._cluster_storage.get_prefix(
self.worker_key_prefix, keys_only=False)
events = []
for worker_id, data in resp.items():
event = WatchEvent(storage_item=StorageItem(key=worker_id,
value=data),
event_type=WatchEventType.SET)
workers.append(self._parse_worker_info(event))
events.append(event)
await self._watch_handle.add_events(events)
async def worker_event_loop():
logger.info(
f"Start watching worker events with {len(workers)} existing workers"
)
for worker_info in workers:
await on_event(worker_info, WatchEventType.SET)
while True:
try:
worker_events = await self._watch_handle.drain()
for event in worker_events:
worker_info = self._parse_worker_info(event)
await on_event(worker_info, event.event_type)
except asyncio.CancelledError:
break
except Exception as e:
logger.error(
f"Error updating routers by worker events: {e}")
await asyncio.sleep(1)
logger.info("Stop watching worker events")
if on_event:
self._watch_task = asyncio.create_task(worker_event_loop())
return workers
async def unwatch_workers(self) -> None:
if self._watch_handle:
await self._cluster_storage.unwatch(self.worker_key_prefix)
self._watch_handle = None
if self._watch_task:
self._watch_task.cancel()
self._watch_task = None
async def get_worker_events(
self) -> List[Tuple[WorkerInfo, WatchEventType]]:
if self._watch_handle is None:
raise ValueError("Watch handle is not initialized")
events = await self._watch_handle.drain()
worker_events = []
for event in events:
try:
worker_info = self._parse_worker_info(event)
worker_events.append((worker_info, event.event_type))
except Exception as e:
logger.error(
f"Failed to parse worker info: {event.storage_item.value}, error: {e}"
)
continue
return worker_events
def _log_cluster_status(self, worker_info: WorkerInfo, change_event: str):
logger.info(
f"Worker {worker_info.worker_id} becomes {change_event}, current context worker: {self.current_ctx_worker_num}/{self._minimal_ctx_worker_num}, current generation worker: {self.current_gen_worker_num}/{self._minimal_gen_worker_num}"
)
def _get_workers(self, role: ServerRole) -> dict[str, WorkerInfo]:
if role == ServerRole.CONTEXT:
return self._current_ctx_workers
elif role == ServerRole.GENERATION:
return self._current_gen_workers
else:
raise ValueError(f"Invalid worker role: {role}")
    def _get_workers_by_id(self, worker_id: str) -> Optional[dict[str, WorkerInfo]]:
        if worker_id in self._current_ctx_workers:
            return self._current_ctx_workers
        elif worker_id in self._current_gen_workers:
            return self._current_gen_workers
        # Unknown worker: return None and let the caller decide how to handle it.
        return None
def _parse_worker_info(self, event: WatchEvent) -> WorkerInfo:
        # Parse the worker info from the event. For a DELETE event, pop the
        # corresponding worker from the current workers; for a SET event, parse
        # the worker info from the value and add it to the current workers.
        # A dummy WorkerInfo (id only) is returned when the event cannot be parsed.
        if event.event_type == WatchEventType.DELETE:
            workers = self._get_workers_by_id(event.storage_item.key)
            if workers is None:
                logger.warning(
                    f"Failed to parse delete event: worker {event.storage_item.key} is unknown"
                )
worker_info = WorkerInfo(worker_id=event.storage_item.key)
else:
worker_info = workers.pop(event.storage_item.key)
elif event.event_type == WatchEventType.SET:
try:
worker_info = WorkerInfo(**json.loads(event.storage_item.value))
worker_info.role = ServerRole(worker_info.role)
workers = self._get_workers(worker_info.role)
workers[event.storage_item.key] = worker_info
except Exception as e:
logger.error(
f"Failed to parse set event: {event.storage_item.key}: {event.storage_item.value}, error: {e}"
)
# Generate a dummy worker info with id only, router should be able to ignore it
worker_info = WorkerInfo(worker_id=event.storage_item.key)
else:
raise ValueError(f"Invalid event type: {event.event_type}")
self._log_cluster_status(
worker_info, "active/updated"
if event.event_type == WatchEventType.SET else "inactive")
return worker_info
    async def is_ready(self) -> bool:
        return (self.current_ctx_worker_num >= self._minimal_ctx_worker_num
                and self.current_gen_worker_num >= self._minimal_gen_worker_num)

    async def is_ready_with_router(self, router_ctx_worker_num: int,
                                   router_gen_worker_num: int) -> bool:
        return (router_ctx_worker_num >= self._minimal_ctx_worker_num
                and router_gen_worker_num >= self._minimal_gen_worker_num)
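

# Illustrative usage sketch (not part of the original module): assuming a concrete
# ClusterStorage backend and a populated DisaggClusterConfig, a router process could
# consume worker membership changes roughly like this:
#
#   async def on_worker_event(worker: WorkerInfo, event: WatchEventType) -> None:
#       ...  # add or remove the worker from the routing table
#
#   manager = DisaggClusterManager(config, storage)
#   await manager.start()
#   existing_workers = await manager.watch_workers(on_event=on_worker_event)
#   ...
#   await manager.stop()
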
class DisaggClusterWorker:
"""
The cluster worker is responsible for registering and deregistering the worker to the cluster storage.
It will send heartbeat to the cluster storage every heartbeat_interval_sec seconds.
If the worker heartbeat fails, it will re-register itself.
"""
LOCALHOST_IPS = ["localhost", "127.0.0.1", "0.0.0.0", "::1",
"::"] # nosec B104
def __init__(self, role: ServerRole, host: str, port: int,
config: DisaggClusterConfig, storage: ClusterStorage):
self._role = role
self._port = port
self._config = config
self._cluster_storage = storage
self._stop = False
self._heartbeat_task = None
self._last_heartbeat = 0
register_host = host
        # If the bind host is a localhost address and the cluster endpoint is not,
        # register with the local IP address as seen from the cluster endpoint.
disagg_host = get_host_from_uri(self._config.cluster_uri)
if host in self.LOCALHOST_IPS and disagg_host not in self.LOCALHOST_IPS:
register_host = get_local_ip(disagg_host)
self._host = register_host
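        # Compose a cluster-unique worker id from the role, registered address,
        # millisecond timestamp, process id and a random suffix so that restarts
        # on the same host:port do not collide.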
self._worker_id = f"{role.name}-{register_host}:{port}-{int(time.time()*1000)}-{os.getpid()}-{random.randint(0, 1000):03}"
def __del__(self):
try:
if asyncio.get_event_loop():
asyncio.run_coroutine_threadsafe(self.deregister_worker(),
asyncio.get_event_loop())
except RuntimeError:
# the event loop may not be running when the worker is destroyed
pass
@property
def worker_id(self) -> str:
return self._worker_id
@property
def worker_info(self) -> WorkerInfo:
return WorkerInfo(worker_id=self._worker_id,
role=self._role,
host=self._host,
port=self._port)
@property
def worker_key(self) -> str:
return get_worker_key(self._config.cluster_name, self._role,
self._worker_id)
async def register_worker(self, validator=None, retry_interval=5) -> bool:
self._stop = False
await self._cluster_storage.start()
if validator and not validator():
logger.warning(
f"Worker {self.worker_info.worker_id} is not valid, skipping registration"
)
return False
worker_info = self.worker_info
logger.debug(
f"Worker {self.worker_info.worker_id} registering, {asdict(worker_info)}"
)
success = await self._cluster_storage.set(
self.worker_key,
json.dumps(asdict(worker_info)),
ttl=self._config.inactive_timeout_sec)
        if not success:
            if retry_interval > 0:
                logger.warning(
                    f"Worker {self.worker_info.worker_id} registration failed, retrying in {retry_interval} seconds"
                )
                await asyncio.sleep(retry_interval)
                return await self.register_worker(validator, retry_interval)
            logger.error(
                f"Worker {self.worker_info.worker_id} registration failed")
            return False
        logger.info(
            f"Worker {self.worker_info.worker_id} registration successful")
        self._last_heartbeat = key_time()
        if 0 < self._config.heartbeat_interval_sec < self._config.inactive_timeout_sec:
            if not self._heartbeat_task:
                self._heartbeat_task = asyncio.create_task(
                    self._heartbeat(validator))
        else:
            logger.warning(
                f"Heartbeat interval {self._config.heartbeat_interval_sec} must be positive and less than "
                f"the inactive timeout {self._config.inactive_timeout_sec}; heartbeat is disabled"
            )
return True
    async def deregister_worker(self) -> bool:
        self._stop = True
        if self._heartbeat_task:
            self._heartbeat_task.cancel()
            self._heartbeat_task = None
        # Delete the registration before shutting down the storage client so the
        # delete request can still go through.
        success = await self._cluster_storage.delete(self.worker_key)
        if not success:
            logger.warning(
                f"Worker {self.worker_info.worker_id} deregistration failed")
        await self._cluster_storage.stop()
        return success
async def _heartbeat(self, validator=None):
logger.info(f"Worker {self.worker_info.worker_id} heartbeat started")
while not self._stop:
remaining_time = self._config.heartbeat_interval_sec - (
key_time() - self._last_heartbeat)
if remaining_time > 0:
await asyncio.sleep(remaining_time)
self._last_heartbeat = key_time()
if validator and not validator():
logger.warning(
f"Worker {self.worker_info.worker_id} is not valid, skipping heartbeat {key_time()}"
)
continue
expire_res = await self._cluster_storage.expire(
self.worker_key, self._config.inactive_timeout_sec)
if not expire_res:
logger.warning(
f"Worker {self.worker_info.worker_id} heartbeat failed, re-registering {key_time()}"
)
await self.register_worker(validator)
else:
logger.debug(
f"Worker {self.worker_info.worker_id} heartbeat successful {key_time()}"
)
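

# Illustrative lifecycle sketch (not part of the original module): a worker process
# would typically register itself once its server is up and deregister on shutdown:
#
#   worker = DisaggClusterWorker(ServerRole.GENERATION, host, port, config, storage)
#   registered = await worker.register_worker(validator=lambda: server_is_healthy())
#   ...  # serve requests; periodic heartbeats keep the registration alive
#   await worker.deregister_worker()
#
# Here `server_is_healthy` is a hypothetical readiness check supplied by the caller.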