# TensorRT-LLMs/tensorrt_llm/serve/cluster_storage.py
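"""Cluster key/value storage backends for tensorrt_llm.serve.

Defines the abstract ClusterStorage interface (set/get/delete/expire, prefix
queries and prefix watches over string keys with optional TTLs), an HTTP-backed
server/client pair built on FastAPI and aiohttp, and an etcd3-backed
implementation.
"""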
import abc
import asyncio
import time
from dataclasses import dataclass
from enum import IntEnum
from functools import wraps
from typing import Callable, Dict, List, Optional
import aiohttp
import etcd3
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from tensorrt_llm.logger import logger
class StorageItem(BaseModel):
key: str
value: Optional[str] = ""
expire_time: Optional[int] = -1
ttl: Optional[int] = -1
overwrite_if_exists: Optional[bool] = False
class WatchEventType(IntEnum):
SET = 0
DELETE = 1
@dataclass
class WatchEvent:
storage_item: StorageItem
event_type: WatchEventType
class WatchEventQueue:
def __init__(self, key_prefixes: List[str]):
self.key_prefixes = key_prefixes
self.events = asyncio.Queue()
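    # drain() blocks until at least one event is available, then returns every
    # event already queued, so callers can handle watch updates in batches.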
async def drain(self):
events = []
event = await self.events.get()
events.append(event)
while not self.events.empty():
event = self.events.get_nowait()
events.append(event)
self.events.task_done()
return events
async def add_events(self, events: List[WatchEvent]):
loop = asyncio.get_event_loop()
for event in events:
self.events.put_nowait(event)
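        # Wake the event loop so a coroutine blocked in drain() sees the events.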
loop._write_to_self()
class ClusterStorage(abc.ABC):
def __init__(self, cluster_uri: str, cluster_name: str):
...
# start the storage, if it's already started, do nothing
async def start(self):
...
# stop the storage, if it's already stopped, do nothing
async def stop(self):
...
# set the key with the value, if the key already exists and overwrite_if_exists is False, return False
async def set(self,
key: str,
value: str,
overwrite_if_exists=False,
ttl: int = -1) -> bool:
...
    # refresh the key's TTL
async def expire(self, key: str, ttl: int) -> bool:
...
# get the value of the key, return None if the key does not exist or is expired
    async def get(self, key: str) -> Optional[str]:
...
# delete the key, return True if the key is deleted, False otherwise
async def delete(self, key: str) -> bool:
...
# watch the key prefix, return the watch event queue
async def watch(self, key_prefix: str) -> WatchEventQueue:
...
# unwatch the key prefix, if the key prefix is not in the watch list, raise a KeyError
async def unwatch(self, key_prefix: str) -> None:
...
# get the value of the key prefix, return the dict of key and value
# if keys_only is True, the value will be empty string
async def get_prefix(self,
key_prefix: str,
keys_only: bool = False) -> Dict[str, str]:
...
def create_cluster_storage(cluster_uri, cluster_name, **kwargs):
if cluster_uri.startswith("http://") or cluster_uri.startswith("https://"):
return HttpClusterStorageServer(cluster_uri, cluster_name, **kwargs)
elif cluster_uri.startswith("etcd://"):
return Etcd3ClusterStorage(cluster_uri, cluster_name, **kwargs)
raise ValueError(f"Invalid cluster storage URI: {cluster_uri}")
def create_cluster_storage_client(cluster_uri, cluster_name, **kwargs):
if cluster_uri.startswith("http://") or cluster_uri.startswith("https://"):
return HttpClusterStorageClient(cluster_uri, cluster_name, **kwargs)
elif cluster_uri.startswith("etcd://"):
return Etcd3ClusterStorage(cluster_uri, cluster_name, **kwargs)
raise ValueError(f"Invalid cluster storage URI: {cluster_uri}")
# All HTTP endpoints return {"result": <result>}; the status code is 400 when
# the result is falsy (e.g. False, None or an empty value) and 200 otherwise.
def jsonify(f):
@wraps(f)
async def wrapper(*args, **kwargs):
result = await f(*args, **kwargs)
return JSONResponse({"result": result},
status_code=200 if result else 400)
return wrapper
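# Expiry bookkeeping uses a monotonic clock so TTLs are unaffected by
# wall-clock adjustments (e.g. NTP corrections).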
def key_time():
return time.monotonic()
class HttpClusterStorageServer(ClusterStorage):
def __init__(self,
cluster_uri,
cluster_name,
server: FastAPI = None,
**kwargs):
self._storage = {}
self._lock = asyncio.Lock()
self._watch_handles = {}
self._watch_lock = asyncio.Lock()
self._check_expired_task = None
self._check_expired_interval = 1 # in seconds
if server:
self.add_routes(server)
def add_routes(self, server: FastAPI):
server.add_api_route("/set", jsonify(self._set), methods=["POST"])
server.add_api_route("/get", jsonify(self.get), methods=["GET"])
server.add_api_route("/delete",
jsonify(self.delete),
methods=["DELETE"])
server.add_api_route("/expire", jsonify(self.expire), methods=["GET"])
server.add_api_route("/get_prefix",
jsonify(self.get_prefix),
methods=["GET"])
async def start(self):
if self._check_expired_task:
return
self._check_expired_task = asyncio.create_task(self._check_expired())
async def stop(self):
if self._check_expired_task:
self._check_expired_task.cancel()
self._check_expired_task = None
async def set(self,
key: str,
value: str,
overwrite_if_exists: bool = False,
ttl: int = -1) -> bool:
storage_item = StorageItem(key=key,
value=value,
overwrite_if_exists=overwrite_if_exists,
ttl=ttl)
return await self._set(storage_item)
async def _set(self, storage_item: StorageItem) -> bool:
async with self._lock:
if storage_item.key in self._storage and not storage_item.overwrite_if_exists:
return False
if storage_item.expire_time < 0 and storage_item.ttl and storage_item.ttl > 0:
storage_item.expire_time = key_time() + storage_item.ttl
self._storage[storage_item.key] = storage_item
await self._notify_watch_event(storage_item.key, storage_item,
WatchEventType.SET)
return True
    async def get(self, key: str) -> Optional[str]:
async with self._lock:
if key in self._storage:
item = self._storage[key]
if item.expire_time < 0 or item.expire_time > key_time():
return item.value
else:
await self._notify_watch_event(key, item,
WatchEventType.DELETE)
self._storage.pop(key)
return None
async def expire(self, key: str, ttl: int) -> bool:
async with self._lock:
if key in self._storage:
self._storage[key].expire_time = key_time() + int(ttl)
return True
return False
async def delete(self, key: str) -> bool:
async with self._lock:
if key in self._storage:
storage_item = self._storage[key]
await self._notify_watch_event(key, storage_item,
WatchEventType.DELETE)
self._storage.pop(key)
return True
return False
    async def get_prefix(self,
                         key_prefix: str,
                         keys_only: bool = False) -> Dict[str, str]:
async with self._lock:
return {
k: "" if keys_only else v.value
for k, v in self._storage.items() if k.startswith(key_prefix)
}
async def watch(self, key_prefix: str) -> WatchEventQueue:
async with self._watch_lock:
if key_prefix in self._watch_handles:
logger.debug(
f"Watch handle for key prefix {key_prefix} already exists, skip"
)
else:
self._watch_handles[key_prefix] = WatchEventQueue(
key_prefixes=[key_prefix])
return self._watch_handles[key_prefix]
async def unwatch(self, key_prefix: str) -> None:
async with self._watch_lock:
if key_prefix in self._watch_handles:
self._watch_handles.pop(key_prefix)
else:
raise KeyError(
f"Key prefix {key_prefix} not in watch list, {self._watch_handles.keys()}"
)
async def _notify_watch_event(self, key, storage_item: StorageItem,
event_type: WatchEventType):
loop = asyncio.get_event_loop()
async with self._watch_lock:
for watch_key, handle in self._watch_handles.items():
if key.startswith(watch_key):
# update queue immediately and wake up the event loop
handle.events.put_nowait(
WatchEvent(storage_item, event_type))
logger.info(
f"Notifying watch event for watch key {watch_key} with type {event_type}"
)
loop._write_to_self()
logger.info(
f"Notified watch event for key {key} with type {event_type}")
async def _check_expired(self):
while True:
await asyncio.sleep(self._check_expired_interval)
try:
before_len = len(self._storage)
current_time = key_time()
async with self._lock:
kv_to_delete = {
k: v
for k, v in self._storage.items()
if v.expire_time > 0 and v.expire_time < current_time
}
for k in kv_to_delete.keys():
self._storage.pop(k)
for k, v in kv_to_delete.items():
await self._notify_watch_event(k, v, WatchEventType.DELETE)
if len(kv_to_delete) > 0:
logger.debug(
f"Checked expired, {before_len} -> {len(self._storage)}, keys to delete: {kv_to_delete.keys()}"
)
except Exception as e:
logger.error(f"Error checking expired: {e}")
class HttpClusterStorageClient(ClusterStorage):
def __init__(self, cluster_uri, cluster_name, **kwargs):
self._session = aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(
total=5))
self._cluster_uri = cluster_uri if cluster_uri.startswith(
"http") else f"http://{cluster_uri}"
self._cluster_name = cluster_name
def __del__(self):
try:
if asyncio.get_event_loop():
asyncio.run_coroutine_threadsafe(self._session.close(),
asyncio.get_event_loop())
except RuntimeError:
pass
def _url_for(self, endpoint: str) -> str:
return f"{self._cluster_uri}/{endpoint}"
    async def _post_json(self,
                         endpoint: str,
                         data: StorageItem,
                         headers: Optional[dict] = None,
                         ignore_result: bool = False) -> bool:
        # Copy the caller's headers so a shared/default dict is never mutated.
        headers = dict(headers) if headers else {}
        headers["Content-Type"] = "application/json"
try:
async with self._session.post(self._url_for(endpoint),
headers=headers,
json=data.model_dump()) as resp:
if resp.status == 200:
json = await resp.json()
return json.get("result") if not ignore_result else True
return None
except (aiohttp.ClientError, OSError) as e:
logger.warning(f"Failed to post {endpoint}, error: {e}")
return False
async def _get(self,
endpoint: str,
ignore_result: bool = False,
**kwargs) -> bool:
try:
async with self._session.get(self._url_for(endpoint),
params=kwargs) as resp:
if resp.status == 200:
json = await resp.json()
return json.get("result") if not ignore_result else True
return None if not ignore_result else False
except (aiohttp.ClientError, OSError) as e:
logger.warning(f"Failed to get {endpoint}, error: {e}")
return False
async def set(self,
key: str,
value: str,
overwrite_if_exists: bool = False,
ttl: int = -1) -> bool:
storage_item = StorageItem(key=key,
value=value,
overwrite_if_exists=overwrite_if_exists,
ttl=ttl)
return await self._post_json("set", storage_item, ignore_result=True)
async def expire(self, key: str, ttl: int) -> bool:
return await self._get("expire",
key=key,
ttl=str(ttl),
ignore_result=True)
    async def get(self, key: str) -> Optional[str]:
return await self._get("get", key=key)
async def get_prefix(self,
key_prefix: str,
keys_only: bool = False) -> Dict[str, str]:
return await self._get("get_prefix",
key_prefix=key_prefix,
keys_only=int(keys_only))
async def delete(self, key: str) -> bool:
try:
async with self._session.delete(self._url_for("delete"),
params={"key": key}) as resp:
return resp.status == 200
except (aiohttp.ClientError, OSError) as e:
logger.warning(f"Failed to delete key {key}, error: {e}")
return False
async def watch(self, key_prefix: str) -> WatchEventQueue:
raise NotImplementedError(
"Watch functionality not implemented for HTTP client")
async def unwatch(self, key_prefix: str) -> None:
raise NotImplementedError(
"Unwatch functionality not implemented for HTTP client")
class Etcd3WatchEventQueue(WatchEventQueue):
def __init__(self,
key_prefix: str,
cancel_event: Callable[[], None] = None):
self.key_prefix = key_prefix
self.events = asyncio.Queue()
self._cancel_event = cancel_event
def cancel_event(self):
if self._cancel_event:
self._cancel_event()
def set_cancel_event(self, cancel_event: Callable[[], None]):
self._cancel_event = cancel_event
def __del__(self):
self.cancel_event()
def add_events_from_resp(self, watch_resp):
try:
for event in watch_resp.events:
# Event type is not in public interface of etcd3
event_type = WatchEventType.SET if "Put" in event.__class__.__name__ else WatchEventType.DELETE
self.events.put_nowait(
WatchEvent(
storage_item=StorageItem(
key=event.key.decode("utf-8"),
value=event.value.decode("utf-8")),
event_type=event_type,
))
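            # This callback runs on the etcd3 watch thread, not on the asyncio
            # loop; wake the queue's loop so a pending drain() returns promptly.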
if self.events._loop:
self.events._loop._write_to_self()
except Exception as e:
logger.error(f"Error adding event: {e}")
self.cancel_event()
def handle_etcd_error(return_on_error: bool | None):

    def decorator(f):

        @wraps(f)
        async def wrap(*args, **kwargs):
            # args[0] is the Etcd3ClusterStorage instance the method is bound to.
            storage = args[0]
            try:
                return await f(*args, **kwargs)
            except etcd3.Etcd3Exception as e:
                logger.error(f"Etcd3 error in {f.__name__}: {e}")
                return return_on_error
            except ValueError:
                # Treat a ValueError from the client as a broken connection:
                # rebuild the client, then report the error value to the caller.
                logger.error(f"Etcd client value error in {f.__name__}")
                storage.reconnect()
                return return_on_error

        return wrap

    return decorator
class Etcd3ClusterStorage(ClusterStorage):
def __init__(self,
cluster_uri: str,
cluster_name: str,
one_single_lease: bool = False,
**kwargs):
cluster_uri = cluster_uri.replace("etcd://", "")
self._host, self._port = cluster_uri.rsplit(":", 1)
self._client = etcd3.client(self._host, self._port)
self._leases = {}
self._instance_lease = None
self._watch_handles = {}
self._one_single_lease = one_single_lease
def __del__(self):
self._watch_handles.clear()
self._client.close()
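    # TTL support is implemented with etcd leases: by default one lease per key,
    # or a single shared lease for all keys when one_single_lease is set.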
def _get_lease(self, key: str, ttl: int = -1) -> etcd3.Lease:
if ttl <= 0:
return None
if self._one_single_lease:
return self._instance_lease
if key not in self._leases:
self._leases[key] = self.client.lease(ttl)
return self._leases[key]
@property
def client(self):
return self._client
def reconnect(self):
logger.info(f"Reconnecting to {self._host}:{self._port}")
self._client = etcd3.client(self._host, self._port)
async def start(self):
# nothing to do
...
async def stop(self):
# nothing to do
...
@handle_etcd_error(return_on_error=False)
async def set(self,
key: str,
value: str,
overwrite_if_exists: bool = False,
ttl: int = -1) -> bool:
lease = self._get_lease(key, ttl)
if not overwrite_if_exists:
return self.client.put_if_not_exists(key, value, lease=lease)
else:
self.client.put(key, value, lease=lease)
return True
@handle_etcd_error(return_on_error=None)
    async def get(self, key: str) -> Optional[str]:
data, meta = self.client.get(key)
return data.decode('utf-8') if data else None
    @handle_etcd_error(return_on_error=False)
    async def delete(self, key: str) -> bool:
        return self.client.delete(key)
@handle_etcd_error(return_on_error=False)
async def expire(self, key: str, ttl: int) -> bool:
if ttl <= 0:
logger.error(f"TTL must be greater than 0, got {ttl}")
return False
lease = self._get_lease(key, ttl)
assert lease is not None, "Lease must be created"
# TTL will be ignored since it can only be set when creating a lease
next(self.client.refresh_lease(lease_id=lease.id), None)
return True
@handle_etcd_error(return_on_error={})
async def get_prefix(self,
key_prefix: str,
keys_only: bool = False) -> Dict[str, str]:
resp = self.client.get_prefix(key_prefix)
return {
metadata.key.decode("utf-8"): "" if keys_only else v.decode("utf-8")
for v, metadata in resp
}
@handle_etcd_error(return_on_error=None)
async def watch(self, key_prefix: str) -> WatchEventQueue:
if key_prefix in self._watch_handles:
return self._watch_handles[key_prefix]
watch_handle = Etcd3WatchEventQueue(key_prefix=key_prefix)
watch_id = self.client.add_watch_prefix_callback(
key_prefix, watch_handle.add_events_from_resp)
watch_handle.set_cancel_event(
lambda: self.client.cancel_watch(watch_id))
self._watch_handles[key_prefix] = watch_handle
return watch_handle
async def unwatch(self, key_prefix: str) -> None:
handle = self._watch_handles.pop(key_prefix)
if handle:
handle.cancel_event()