import abc
import asyncio
import time
from dataclasses import dataclass
from enum import IntEnum
from functools import wraps
from typing import Callable, Dict, List, Optional

import aiohttp
import etcd3
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from pydantic import BaseModel

from tensorrt_llm.logger import logger


class StorageItem(BaseModel):
    key: str
    value: Optional[str] = ""
    expire_time: Optional[int] = -1
    ttl: Optional[int] = -1
    overwrite_if_exists: Optional[bool] = False


class WatchEventType(IntEnum):
    SET = 0
    DELETE = 1


@dataclass
class WatchEvent:
    storage_item: StorageItem
    event_type: WatchEventType


class WatchEventQueue:

    def __init__(self, key_prefixes: List[str]):
        self.key_prefixes = key_prefixes
        self.events = asyncio.Queue()

    async def drain(self):
        # block until at least one event is available, then return every
        # event currently queued as a single batch
        events = []
        event = await self.events.get()
        events.append(event)
        while not self.events.empty():
            event = self.events.get_nowait()
            events.append(event)
        self.events.task_done()
        return events

    async def add_events(self, events: List[WatchEvent]):
        loop = asyncio.get_event_loop()
        for event in events:
            self.events.put_nowait(event)
        # wake the event loop in case the events were queued from another thread
        loop._write_to_self()

class ClusterStorage(abc.ABC):

    def __init__(self, cluster_uri: str, cluster_name: str):
        ...

    # start the storage; if it's already started, do nothing
    async def start(self):
        ...

    # stop the storage; if it's already stopped, do nothing
    async def stop(self):
        ...

    # set the key to the value; if the key already exists and
    # overwrite_if_exists is False, return False
    async def set(self,
                  key: str,
                  value: str,
                  overwrite_if_exists=False,
                  ttl: int = -1) -> bool:
        ...

    # refresh the key's ttl
    async def expire(self, key: str, ttl: int) -> bool:
        ...

    # get the value of the key; return None if the key does not exist or is expired
    async def get(self, key: str) -> str:
        ...

    # delete the key; return True if the key is deleted, False otherwise
    async def delete(self, key: str) -> bool:
        ...

    # watch the key prefix and return the watch event queue
    async def watch(self, key_prefix: str) -> WatchEventQueue:
        ...

    # unwatch the key prefix; if the key prefix is not in the watch list, raise a KeyError
    async def unwatch(self, key_prefix: str) -> None:
        ...

    # get all keys under the key prefix as a dict of key to value;
    # if keys_only is True, each value is an empty string
    async def get_prefix(self,
                         key_prefix: str,
                         keys_only: bool = False) -> Dict[str, str]:
        ...

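# Illustrative sketch, not part of the original module: how a caller might
# drive the ClusterStorage interface end to end. The key names ("worker/0",
# "worker/") and the 10-second TTL are hypothetical placeholders.
async def _example_cluster_storage_usage(storage: "ClusterStorage") -> None:
    await storage.start()
    # Register this worker unless another instance already owns the key.
    registered = await storage.set("worker/0", "ready",
                                   overwrite_if_exists=False, ttl=10)
    if registered:
        # Refreshing the TTL periodically keeps the registration alive.
        await storage.expire("worker/0", ttl=10)
    # Observe membership changes under the shared prefix.
    queue = await storage.watch("worker/")
    for event in await queue.drain():  # blocks until at least one event
        logger.info(f"{event.event_type.name}: {event.storage_item.key}")
    await storage.unwatch("worker/")
    await storage.stop()
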
def create_cluster_storage(cluster_uri, cluster_name, **kwargs):
    if cluster_uri.startswith("http://") or cluster_uri.startswith("https://"):
        return HttpClusterStorageServer(cluster_uri, cluster_name, **kwargs)
    elif cluster_uri.startswith("etcd://"):
        return Etcd3ClusterStorage(cluster_uri, cluster_name, **kwargs)
    raise ValueError(f"Invalid cluster storage URI: {cluster_uri}")


def create_cluster_storage_client(cluster_uri, cluster_name, **kwargs):
    if cluster_uri.startswith("http://") or cluster_uri.startswith("https://"):
        return HttpClusterStorageClient(cluster_uri, cluster_name, **kwargs)
    elif cluster_uri.startswith("etcd://"):
        return Etcd3ClusterStorage(cluster_uri, cluster_name, **kwargs)
    raise ValueError(f"Invalid cluster storage URI: {cluster_uri}")

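# Hedged usage sketch: the factories dispatch on the URI scheme. The endpoints
# below are placeholders, not defaults shipped with this module, and nothing
# in this file calls the helper.
async def _example_create_storage():
    # "etcd://host:port" selects the etcd-backed storage for both roles.
    etcd_storage = create_cluster_storage("etcd://127.0.0.1:2379",
                                          "demo-cluster")
    # "http(s)://host:port" selects the in-memory HTTP server / client pair.
    http_client = create_cluster_storage_client("http://127.0.0.1:8000",
                                                "demo-cluster")
    return etcd_storage, http_client
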
# All HTTP endpoints return {"result": <result>}; the status code is 400 if
# the result is False or None, and 200 otherwise.
def jsonify(f):

    @wraps(f)
    async def wrapper(*args, **kwargs):
        result = await f(*args, **kwargs)
        return JSONResponse({"result": result},
                            status_code=200 if result else 400)

    return wrapper


def key_time():
    # monotonic clock so TTL arithmetic is unaffected by wall-clock changes
    return time.monotonic()

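# Minimal sketch of the response envelope produced by jsonify; the ping
# handler below is hypothetical and exists only for illustration.
async def _example_jsonify_envelope() -> int:

    @jsonify
    async def ping():
        return "pong"

    response = await ping()  # JSONResponse({"result": "pong"}, status_code=200)
    return response.status_code  # 200 because the result is truthy
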
class HttpClusterStorageServer(ClusterStorage):

    def __init__(self,
                 cluster_uri,
                 cluster_name,
                 server: FastAPI = None,
                 **kwargs):
        self._storage = {}
        self._lock = asyncio.Lock()
        self._watch_handles = {}
        self._watch_lock = asyncio.Lock()
        self._check_expired_task = None
        self._check_expired_interval = 1  # in seconds
        if server:
            self.add_routes(server)

    def add_routes(self, server: FastAPI):
        server.add_api_route("/set", jsonify(self._set), methods=["POST"])
        server.add_api_route("/get", jsonify(self.get), methods=["GET"])
        server.add_api_route("/delete",
                             jsonify(self.delete),
                             methods=["DELETE"])
        server.add_api_route("/expire", jsonify(self.expire), methods=["GET"])
        server.add_api_route("/get_prefix",
                             jsonify(self.get_prefix),
                             methods=["GET"])

    async def start(self):
        if self._check_expired_task:
            return
        self._check_expired_task = asyncio.create_task(self._check_expired())

    async def stop(self):
        if self._check_expired_task:
            self._check_expired_task.cancel()
            self._check_expired_task = None

    async def set(self,
                  key: str,
                  value: str,
                  overwrite_if_exists: bool = False,
                  ttl: int = -1) -> bool:
        storage_item = StorageItem(key=key,
                                   value=value,
                                   overwrite_if_exists=overwrite_if_exists,
                                   ttl=ttl)
        return await self._set(storage_item)

    async def _set(self, storage_item: StorageItem) -> bool:
        async with self._lock:
            if storage_item.key in self._storage and not storage_item.overwrite_if_exists:
                return False
            if storage_item.expire_time < 0 and storage_item.ttl and storage_item.ttl > 0:
                storage_item.expire_time = key_time() + storage_item.ttl
            self._storage[storage_item.key] = storage_item
            await self._notify_watch_event(storage_item.key, storage_item,
                                           WatchEventType.SET)
            return True

    async def get(self, key: str) -> str:
        async with self._lock:
            if key in self._storage:
                item = self._storage[key]
                if item.expire_time < 0 or item.expire_time > key_time():
                    return item.value
                else:
                    await self._notify_watch_event(key, item,
                                                   WatchEventType.DELETE)
                    self._storage.pop(key)
            return None

    async def expire(self, key: str, ttl: int) -> bool:
        async with self._lock:
            if key in self._storage:
                self._storage[key].expire_time = key_time() + int(ttl)
                return True
            return False

    async def delete(self, key: str) -> bool:
        async with self._lock:
            if key in self._storage:
                storage_item = self._storage[key]
                await self._notify_watch_event(key, storage_item,
                                               WatchEventType.DELETE)
                self._storage.pop(key)
                return True
            return False

    async def get_prefix(self,
                         key_prefix: str,
                         keys_only: bool = False) -> Dict[str, str]:
        async with self._lock:
            return {
                k: "" if keys_only else v.value
                for k, v in self._storage.items() if k.startswith(key_prefix)
            }

    async def watch(self, key_prefix: str) -> WatchEventQueue:
        async with self._watch_lock:
            if key_prefix in self._watch_handles:
                logger.debug(
                    f"Watch handle for key prefix {key_prefix} already exists, skip"
                )
            else:
                self._watch_handles[key_prefix] = WatchEventQueue(
                    key_prefixes=[key_prefix])
            return self._watch_handles[key_prefix]

    async def unwatch(self, key_prefix: str) -> None:
        async with self._watch_lock:
            if key_prefix in self._watch_handles:
                self._watch_handles.pop(key_prefix)
            else:
                raise KeyError(
                    f"Key prefix {key_prefix} not in watch list, {self._watch_handles.keys()}"
                )

    async def _notify_watch_event(self, key, storage_item: StorageItem,
                                  event_type: WatchEventType):
        loop = asyncio.get_event_loop()
        async with self._watch_lock:
            for watch_key, handle in self._watch_handles.items():
                if key.startswith(watch_key):
                    # update the queue immediately and wake up the event loop
                    handle.events.put_nowait(
                        WatchEvent(storage_item, event_type))
                    logger.info(
                        f"Notifying watch event for watch key {watch_key} with type {event_type}"
                    )
                    loop._write_to_self()
        logger.info(
            f"Notified watch event for key {key} with type {event_type}")

    async def _check_expired(self):
        while True:
            await asyncio.sleep(self._check_expired_interval)
            try:
                before_len = len(self._storage)
                current_time = key_time()
                async with self._lock:
                    kv_to_delete = {
                        k: v
                        for k, v in self._storage.items()
                        if v.expire_time > 0 and v.expire_time < current_time
                    }
                    for k in kv_to_delete.keys():
                        self._storage.pop(k)
                for k, v in kv_to_delete.items():
                    await self._notify_watch_event(k, v, WatchEventType.DELETE)
                if len(kv_to_delete) > 0:
                    logger.debug(
                        f"Checked expired, {before_len} -> {len(self._storage)}, keys to delete: {kv_to_delete.keys()}"
                    )
            except Exception as e:
                logger.error(f"Error checking expired: {e}")

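# Hedged wiring sketch (app and cluster names are placeholders): register the
# storage routes on an existing FastAPI app and tie the expiration task to the
# app's lifecycle. This helper is illustrative and is not used by the module.
def _example_http_server_app() -> FastAPI:
    app = FastAPI()
    storage = HttpClusterStorageServer("http://0.0.0.0:8000", "demo-cluster",
                                       server=app)
    # start()/stop() manage the background task that evicts expired keys.
    app.add_event_handler("startup", storage.start)
    app.add_event_handler("shutdown", storage.stop)
    return app
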
class HttpClusterStorageClient(ClusterStorage):

    def __init__(self, cluster_uri, cluster_name, **kwargs):
        self._session = aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(
            total=5))
        self._cluster_uri = cluster_uri if cluster_uri.startswith(
            "http") else f"http://{cluster_uri}"
        self._cluster_name = cluster_name

    def __del__(self):
        try:
            if asyncio.get_event_loop():
                asyncio.run_coroutine_threadsafe(self._session.close(),
                                                 asyncio.get_event_loop())
        except RuntimeError:
            pass

    def _url_for(self, endpoint: str) -> str:
        return f"{self._cluster_uri}/{endpoint}"

    async def _post_json(self,
                         endpoint: str,
                         data: StorageItem,
                         headers: dict = {},
                         ignore_result: bool = False) -> bool:
        headers["Content-Type"] = "application/json"
        try:
            async with self._session.post(self._url_for(endpoint),
                                          headers=headers,
                                          json=data.model_dump()) as resp:
                if resp.status == 200:
                    json = await resp.json()
                    return json.get("result") if not ignore_result else True
                return None
        except (aiohttp.ClientError, OSError) as e:
            logger.warning(f"Failed to post {endpoint}, error: {e}")
            return False

    async def _get(self,
                   endpoint: str,
                   ignore_result: bool = False,
                   **kwargs) -> bool:
        try:
            async with self._session.get(self._url_for(endpoint),
                                         params=kwargs) as resp:
                if resp.status == 200:
                    json = await resp.json()
                    return json.get("result") if not ignore_result else True
                return None if not ignore_result else False
        except (aiohttp.ClientError, OSError) as e:
            logger.warning(f"Failed to get {endpoint}, error: {e}")
            return False

    async def set(self,
                  key: str,
                  value: str,
                  overwrite_if_exists: bool = False,
                  ttl: int = -1) -> bool:
        storage_item = StorageItem(key=key,
                                   value=value,
                                   overwrite_if_exists=overwrite_if_exists,
                                   ttl=ttl)
        return await self._post_json("set", storage_item, ignore_result=True)

    async def expire(self, key: str, ttl: int) -> bool:
        return await self._get("expire",
                               key=key,
                               ttl=str(ttl),
                               ignore_result=True)

    async def get(self, key: str) -> str:
        return await self._get("get", key=key)

    async def get_prefix(self,
                         key_prefix: str,
                         keys_only: bool = False) -> Dict[str, str]:
        return await self._get("get_prefix",
                               key_prefix=key_prefix,
                               keys_only=int(keys_only))

    async def delete(self, key: str) -> bool:
        try:
            async with self._session.delete(self._url_for("delete"),
                                            params={"key": key}) as resp:
                return resp.status == 200
        except (aiohttp.ClientError, OSError) as e:
            logger.warning(f"Failed to delete key {key}, error: {e}")
            return False

    async def watch(self, key_prefix: str) -> WatchEventQueue:
        raise NotImplementedError(
            "Watch functionality not implemented for HTTP client")

    async def unwatch(self, key_prefix: str) -> None:
        raise NotImplementedError(
            "Unwatch functionality not implemented for HTTP client")

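# Hedged usage sketch (endpoint and key names are placeholders): exercise the
# HTTP client against a running HttpClusterStorageServer.
async def _example_http_client_usage():
    client = HttpClusterStorageClient("http://127.0.0.1:8000", "demo-cluster")
    await client.set("worker/0", "ready", ttl=10)
    value = await client.get("worker/0")          # "ready" until the TTL lapses
    members = await client.get_prefix("worker/")  # e.g. {"worker/0": "ready"}
    await client.delete("worker/0")
    return value, members
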
class Etcd3WatchEventQueue(WatchEventQueue):

    def __init__(self,
                 key_prefix: str,
                 cancel_event: Callable[[], None] = None):
        self.key_prefix = key_prefix
        self.events = asyncio.Queue()
        self._cancel_event = cancel_event

    def cancel_event(self):
        if self._cancel_event:
            self._cancel_event()

    def set_cancel_event(self, cancel_event: Callable[[], None]):
        self._cancel_event = cancel_event

    def __del__(self):
        self.cancel_event()

    def add_events_from_resp(self, watch_resp):
        # called from etcd3's watcher thread, not from the asyncio event loop
        try:
            for event in watch_resp.events:
                # Event type is not in the public interface of etcd3
                event_type = WatchEventType.SET if "Put" in event.__class__.__name__ else WatchEventType.DELETE
                self.events.put_nowait(
                    WatchEvent(
                        storage_item=StorageItem(
                            key=event.key.decode("utf-8"),
                            value=event.value.decode("utf-8")),
                        event_type=event_type,
                    ))
            if self.events._loop:
                # wake the event loop so a pending drain() sees the new events
                self.events._loop._write_to_self()
        except Exception as e:
            logger.error(f"Error adding event: {e}")
            self.cancel_event()

def handle_etcd_error(return_on_error: bool | None):

    def decorator(f):

        async def wrap(*args, **kwargs):
            storage = args[0]
            try:
                return await f(*args, **kwargs)
            except etcd3.Etcd3Exception as e:
                logger.error(f"Etcd3 error in {f.__name__}: {e}")
                return return_on_error
            except ValueError:
                logger.error(f"Etcd client value error in {f.__name__}")
                storage.reconnect()
                return return_on_error

        return wrap

    return decorator

class Etcd3ClusterStorage(ClusterStorage):

    def __init__(self,
                 cluster_uri: str,
                 cluster_name: str,
                 one_single_lease: bool = False,
                 **kwargs):
        cluster_uri = cluster_uri.replace("etcd://", "")
        self._host, self._port = cluster_uri.rsplit(":", 1)
        self._client = etcd3.client(self._host, self._port)
        self._leases = {}
        self._instance_lease = None
        self._watch_handles = {}
        self._one_single_lease = one_single_lease

    def __del__(self):
        self._watch_handles.clear()
        self._client.close()

    def _get_lease(self, key: str, ttl: int = -1) -> etcd3.Lease:
        if ttl <= 0:
            return None
        if self._one_single_lease:
            return self._instance_lease
        if key not in self._leases:
            self._leases[key] = self.client.lease(ttl)
        return self._leases[key]

    @property
    def client(self):
        return self._client

    def reconnect(self):
        logger.info(f"Reconnecting to {self._host}:{self._port}")
        self._client = etcd3.client(self._host, self._port)

    async def start(self):
        # nothing to do
        ...

    async def stop(self):
        # nothing to do
        ...

    @handle_etcd_error(return_on_error=False)
    async def set(self,
                  key: str,
                  value: str,
                  overwrite_if_exists: bool = False,
                  ttl: int = -1) -> bool:
        lease = self._get_lease(key, ttl)
        if not overwrite_if_exists:
            return self.client.put_if_not_exists(key, value, lease=lease)
        else:
            self.client.put(key, value, lease=lease)
            return True

    @handle_etcd_error(return_on_error=None)
    async def get(self, key: str) -> str:
        data, meta = self.client.get(key)
        return data.decode('utf-8') if data else None

    @handle_etcd_error(return_on_error=False)
    async def delete(self, key: str) -> bool:
        # return the deletion status so the ClusterStorage contract is honored
        return self.client.delete(key)

    @handle_etcd_error(return_on_error=False)
    async def expire(self, key: str, ttl: int) -> bool:
        if ttl <= 0:
            logger.error(f"TTL must be greater than 0, got {ttl}")
            return False
        lease = self._get_lease(key, ttl)
        assert lease is not None, "Lease must be created"
        # TTL will be ignored since it can only be set when creating a lease
        next(self.client.refresh_lease(lease_id=lease.id), None)
        return True

    @handle_etcd_error(return_on_error={})
    async def get_prefix(self,
                         key_prefix: str,
                         keys_only: bool = False) -> Dict[str, str]:
        resp = self.client.get_prefix(key_prefix)
        return {
            metadata.key.decode("utf-8"): "" if keys_only else v.decode("utf-8")
            for v, metadata in resp
        }

    @handle_etcd_error(return_on_error=None)
    async def watch(self, key_prefix: str) -> WatchEventQueue:
        if key_prefix in self._watch_handles:
            return self._watch_handles[key_prefix]
        watch_handle = Etcd3WatchEventQueue(key_prefix=key_prefix)
        watch_id = self.client.add_watch_prefix_callback(
            key_prefix, watch_handle.add_events_from_resp)
        watch_handle.set_cancel_event(
            lambda: self.client.cancel_watch(watch_id))
        self._watch_handles[key_prefix] = watch_handle
        return watch_handle

    async def unwatch(self, key_prefix: str) -> None:
        handle = self._watch_handles.pop(key_prefix)
        if handle:
            handle.cancel_event()

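# Hedged usage sketch (the etcd endpoint and keys are placeholders, and a
# reachable etcd server is assumed): the etcd backend surfaces watch events
# through the same WatchEventQueue interface as the in-memory HTTP server,
# with etcd leases backing the TTL semantics.
async def _example_etcd_storage_usage():
    storage = Etcd3ClusterStorage("etcd://127.0.0.1:2379", "demo-cluster")
    await storage.set("worker/0", "ready", ttl=10)
    queue = await storage.watch("worker/")
    await storage.set("worker/1", "ready", ttl=10)
    events = await queue.drain()  # at least the SET event for worker/1
    await storage.unwatch("worker/")
    return events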