mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-02-16 07:53:55 +08:00
[https://nvbugs/5826689][fix] replace etcd3 with etcd-sdk-python (#10886)
Signed-off-by: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com> Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
This commit is contained in:
parent
c659280445
commit
6c4e0c3dbe
@ -64,7 +64,6 @@ nvtx
|
||||
matplotlib # FIXME: this is added to make nvtx happy
|
||||
meson
|
||||
ninja
|
||||
etcd3 @ git+https://github.com/kragniz/python-etcd3.git@e58a899579ba416449c4e225b61f039457c8072a
|
||||
blake3
|
||||
soundfile
|
||||
triton==3.5.1 # NOTE: if you update this, you must also run scripts/vendor_triton_kernels.py to vendor the new version of triton_kernels
|
||||
@ -83,3 +82,4 @@ cuda-core
|
||||
llist
|
||||
cuda-tile>=1.0.1
|
||||
nvidia-cuda-tileiras>=13.1
|
||||
etcd-sdk-python==0.0.7
|
||||
|
||||
@ -1,13 +1,16 @@
|
||||
import abc
|
||||
import asyncio
|
||||
import importlib.metadata as importlib_metadata
|
||||
import importlib.util
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from enum import IntEnum
|
||||
from functools import wraps
|
||||
from typing import Callable, Dict, List, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import aiohttp
|
||||
import etcd3
|
||||
from fastapi import FastAPI
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
@ -15,6 +18,45 @@ from pydantic import BaseModel
|
||||
from tensorrt_llm.logger import logger
|
||||
|
||||
|
||||
def _find_module_file_in_distribution(dist, module_name: str):
|
||||
module_path = module_name.replace(".", "/")
|
||||
candidates = (f"{module_path}/__init__.py", f"{module_path}.py")
|
||||
for dist_file in dist.files or []:
|
||||
dist_file_str = str(dist_file)
|
||||
if dist_file_str.endswith(candidates):
|
||||
return dist.locate_file(dist_file)
|
||||
return None
|
||||
|
||||
|
||||
def load_module_from_distribution(dist_name: str, module_name: str):
|
||||
dist = importlib_metadata.distribution(dist_name)
|
||||
|
||||
module_file = _find_module_file_in_distribution(dist, module_name)
|
||||
if not module_file:
|
||||
raise ModuleNotFoundError(
|
||||
f"{module_name} not found in distribution {dist_name}")
|
||||
|
||||
load_name = module_name
|
||||
spec = importlib.util.spec_from_file_location(load_name, module_file)
|
||||
if spec is None or spec.loader is None:
|
||||
raise ImportError(
|
||||
f"Could not create a module spec for {module_name} in {dist_name}")
|
||||
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
sys.modules[load_name] = module
|
||||
try:
|
||||
spec.loader.exec_module(module)
|
||||
except Exception:
|
||||
sys.modules.pop(load_name, None)
|
||||
raise
|
||||
return module
|
||||
|
||||
|
||||
# pyectd and etcd-sdk-python both have package name "pyetcd", we need to find the correct one
|
||||
# by distribution name
|
||||
etcd3 = load_module_from_distribution("etcd-sdk-python", "pyetcd")
|
||||
|
||||
|
||||
class StorageItem(BaseModel):
|
||||
key: str
|
||||
value: Optional[str] = ""
|
||||
@ -455,14 +497,25 @@ def handle_etcd_error(return_on_error: bool | None):
|
||||
|
||||
class Etcd3ClusterStorage(ClusterStorage):
|
||||
|
||||
@staticmethod
|
||||
def _connect(cluster_uri: str) -> etcd3.MultiEndpointEtcd3Client:
|
||||
logger.info(f"Connecting to {cluster_uri}")
|
||||
endpoints = []
|
||||
for url in cluster_uri.split(","):
|
||||
parsed_url = urlparse(url)
|
||||
endpoints.append(
|
||||
etcd3.Endpoint(parsed_url.hostname,
|
||||
parsed_url.port,
|
||||
secure=False))
|
||||
return etcd3.MultiEndpointEtcd3Client(endpoints, timeout=10)
|
||||
|
||||
def __init__(self,
|
||||
cluster_uri: str,
|
||||
cluster_name: str,
|
||||
one_single_lease: bool = False,
|
||||
**kwargs):
|
||||
cluster_uri = cluster_uri.replace("etcd://", "")
|
||||
self._host, self._port = cluster_uri.rsplit(":", 1)
|
||||
self._client = etcd3.client(self._host, self._port)
|
||||
self._cluster_uri = cluster_uri
|
||||
self._client = self._connect(self._cluster_uri)
|
||||
self._leases = {}
|
||||
self._instance_lease = None
|
||||
self._watch_handles = {}
|
||||
@ -486,8 +539,7 @@ class Etcd3ClusterStorage(ClusterStorage):
|
||||
return self._client
|
||||
|
||||
def reconnect(self):
|
||||
logger.info(f"Reconnecting to {self._host}:{self._port}")
|
||||
self._client = etcd3.client(self._host, self._port)
|
||||
self._client = self._connect(self._cluster_uri)
|
||||
|
||||
async def start(self):
|
||||
# nothing to do
|
||||
|
||||
@ -3,12 +3,7 @@ from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from tensorrt_llm.llmapi.disagg_utils import MetadataServerConfig
|
||||
from tensorrt_llm.logger import logger
|
||||
|
||||
try:
|
||||
import etcd3
|
||||
except Exception as e:
|
||||
logger.warning(f"etcd3 is not installed correctly: {e}")
|
||||
from tensorrt_llm.serve.cluster_storage import etcd3
|
||||
|
||||
|
||||
class RemoteDictionary(ABC):
|
||||
|
||||
@ -5,9 +5,7 @@ import subprocess
|
||||
import time
|
||||
import unittest
|
||||
|
||||
import etcd3
|
||||
|
||||
from tensorrt_llm.serve.metadata_server import EtcdDictionary
|
||||
from tensorrt_llm.serve.metadata_server import EtcdDictionary, etcd3
|
||||
|
||||
|
||||
def wait_for_port(host, port, timeout=15):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user