TensorRT-LLMs/tensorrt_llm/serve/metadata_server.py
Shunkangz ae9a6cf24f
feat: Add integration of etcd (#3738)
Signed-off-by: Shunkang <182541032+Shunkangz@users.noreply.github.co>
Signed-off-by: BatshevaBlack <132911331+BatshevaBlack@users.noreply.github.com>
Co-authored-by: Shunkang <182541032+Shunkangz@users.noreply.github.co>
Co-authored-by: Batsheva Black <bblack@login-eos01.eos.clusters.nvidia.com>
Co-authored-by: BatshevaBlack <132911331+BatshevaBlack@users.noreply.github.com>
2025-06-03 20:01:44 +08:00

96 lines
2.4 KiB
Python

import json
from abc import ABC, abstractmethod
from typing import Any, Optional

from tensorrt_llm.llmapi.disagg_utils import MetadataServerConfig
from tensorrt_llm.logger import logger
# etcd3 is only used by the etcd backend (EtcdDictionary) in this module, so a
# failed import is logged instead of aborting import of the whole module.
# Any Exception is caught, not just ImportError, presumably because a broken
# install can fail in other ways -- hence the "not installed correctly" wording.
# NOTE(review): if this fails, the name ``etcd3`` stays unbound and constructing
# an EtcdDictionary will raise NameError.
try:
    import etcd3
except Exception as import_err:
    logger.warning(f"etcd3 is not installed correctly: {import_err}")
class RemoteDictionary(ABC):
    """Abstract interface for the raw key/value store behind the metadata server.

    Concrete backends (e.g. ``EtcdDictionary``) implement the storage itself;
    ``JsonDictionary`` layers JSON encoding on top of any implementation.
    """

    @abstractmethod
    def get(self, key: str) -> tuple[Optional[bytes], Any]:
        """Return the raw ``(value, metadata)`` pair stored under ``key``.

        ``value`` is the stored payload as bytes and is expected to be
        ``None`` when ``key`` does not exist; ``metadata`` is backend-specific.
        ``JsonDictionary.get`` unpacks this two-element result and decodes
        ``value`` as UTF-8.
        """

    @abstractmethod
    def put(self, key: str, value: str):
        """Store the string ``value`` under ``key``."""

    @abstractmethod
    def remove(self, key: str):
        """Delete ``key`` from the store."""

    @abstractmethod
    def keys(self) -> list[str]:
        """Return the key names exposed by the store.

        An implementation may collapse raw keys (``EtcdDictionary`` returns
        two-segment prefixes, for example).
        """
class EtcdDictionary(RemoteDictionary):
    """``RemoteDictionary`` backed by an etcd server through the ``etcd3`` client."""

    def __init__(self, host: str, port: int):
        """Create an ``etcd3`` client for the server at ``host``:``port``."""
        self._client = etcd3.client(host, port)

    def get(self, key: str) -> tuple[Optional[bytes], Any]:
        """Return etcd's raw ``(value, metadata)`` pair for ``key``.

        The ``etcd3`` client's result is forwarded unchanged;
        ``JsonDictionary.get`` relies on it being a two-element tuple.
        NOTE(review): python-etcd3 returns ``(None, None)`` for a missing
        key -- confirm against the installed etcd3 version.
        """
        return self._client.get(key)

    def put(self, key: str, value: str):
        """Write ``value`` under ``key`` in etcd."""
        self._client.put(key, value)

    def remove(self, key: str):
        """Delete ``key`` from etcd."""
        self._client.delete(key)

    def keys(self) -> list[str]:
        """Return the distinct two-segment prefixes of every key in etcd.

        TODO: Confirm the final save key format. This implementation assumes
        keys look like ``"trtllm/executor_name/key"``. Each key is reduced to
        its first two ``/``-separated segments (``"trtllm/executor_name"``).
        Keys containing no ``/`` are skipped. The result is sorted so callers
        see a deterministic order.
        """
        prefixes = set()
        # Only key names are needed, so ask etcd not to ship every value.
        for _, metadata in self._client.get_all(keys_only=True):
            segments = metadata.key.decode('utf-8').split('/')
            if len(segments) >= 2:
                prefixes.add('/'.join(segments[:2]))
        return sorted(prefixes)
class JsonDictionary:
    """JSON-encoding view over a ``RemoteDictionary``.

    Values are serialized with ``json.dumps`` on write and decoded with
    ``json.loads`` on read. ``remove`` and ``keys`` are forwarded unchanged.
    """

    def __init__(self, dict: "RemoteDictionary"):
        # The parameter name shadows the builtin ``dict``. It is kept because
        # callers may pass it by keyword.
        self._dict = dict

    def get(self, key: str) -> Any:
        """Return the JSON-decoded value stored under ``key``, or ``None`` if absent.

        The backing store's ``get`` must return a ``(value, metadata)`` pair
        whose ``value`` is UTF-8 encoded bytes, or ``None`` for a missing key.
        A stored JSON ``null`` is also returned as ``None``.
        """
        raw_value, _ = self._dict.get(key)
        if raw_value is None:
            # Missing key: there are no bytes to decode.
            return None
        return json.loads(raw_value.decode('utf-8'))

    def put(self, key: str, value: Any):
        """Serialize ``value`` to JSON text and store it under ``key``."""
        self._dict.put(key, json.dumps(value))

    def remove(self, key: str):
        """Delete ``key`` from the backing store."""
        self._dict.remove(key)

    def keys(self) -> list[str]:
        """Return the backing store's ``keys()`` result unchanged."""
        return self._dict.keys()
def create_metadata_server(
    metadata_server_cfg: Optional[MetadataServerConfig]
) -> Optional[JsonDictionary]:
    """Build the JSON-backed metadata store described by ``metadata_server_cfg``.

    Args:
        metadata_server_cfg: Metadata-server settings, or ``None`` when no
            metadata server is configured.

    Returns:
        A ``JsonDictionary`` wrapping an ``EtcdDictionary`` connected to the
        configured hostname/port, or ``None`` if ``metadata_server_cfg`` is
        ``None``.

    Raises:
        ValueError: If ``metadata_server_cfg.server_type`` is not ``'etcd'``.
    """
    if metadata_server_cfg is None:
        return None

    server_type = metadata_server_cfg.server_type
    if server_type != 'etcd':
        raise ValueError(f"Unsupported metadata server type: {server_type}")

    backend = EtcdDictionary(host=metadata_server_cfg.hostname,
                             port=metadata_server_cfg.port)
    return JsonDictionary(backend)