[TRTLLM-7843][feat] implement disagg cluster auto-scaling (#8215)

Signed-off-by: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com>
This commit is contained in:
Lizhi Zhou 2025-10-22 05:25:07 +08:00 committed by GitHub
parent 9b54b3bfaf
commit 23d5280a90
11 changed files with 683 additions and 60 deletions

View File

@ -67,7 +67,7 @@ nvtx
matplotlib # FIXME: this is added to make nvtx happy
meson
ninja
etcd3
etcd3 @ git+https://github.com/kragniz/python-etcd3.git@e58a899579ba416449c4e225b61f039457c8072a
blake3
soundfile
triton==3.3.1; platform_machine == "x86_64"

View File

@ -21,7 +21,9 @@ from tensorrt_llm.executor.utils import LlmLauncherEnvs
from tensorrt_llm.llmapi import (BuildConfig, CapacitySchedulerPolicy,
DynamicBatchConfig, KvCacheConfig,
SchedulerConfig)
from tensorrt_llm.llmapi.disagg_utils import (MetadataServerConfig, ServerRole,
from tensorrt_llm.llmapi.disagg_utils import (DisaggClusterConfig,
MetadataServerConfig, ServerRole,
extract_disagg_cluster_config,
parse_disagg_config_file,
parse_metadata_server_config_file)
from tensorrt_llm.llmapi.llm_utils import update_llm_args_with_extra_dict
@ -140,7 +142,8 @@ def launch_server(host: str,
port: int,
llm_args: dict,
metadata_server_cfg: Optional[MetadataServerConfig] = None,
server_role: Optional[ServerRole] = None):
server_role: Optional[ServerRole] = None,
disagg_cluster_config: Optional[DisaggClusterConfig] = None):
backend = llm_args["backend"]
model = llm_args["model"]
@ -161,7 +164,8 @@ def launch_server(host: str,
server = OpenAIServer(llm=llm,
model=model,
server_role=server_role,
metadata_server_cfg=metadata_server_cfg)
metadata_server_cfg=metadata_server_cfg,
disagg_cluster_config=disagg_cluster_config)
# Optionally disable GC (default: not disabled)
if os.getenv("TRTLLM_SERVER_DISABLE_GC", "0") == "1":
@ -313,6 +317,10 @@ class ChoiceWithAlias(click.Choice):
help=
"Exit with runtime error when attention window is too large to fit even a single sequence in the KV cache."
)
@click.option("--disagg_cluster_uri",
type=str,
default=None,
help="URI of the disaggregated cluster.")
@click.option("--enable_chunked_prefill",
is_flag=True,
default=False,
@ -327,7 +335,7 @@ def serve(
extra_llm_api_options: Optional[str], reasoning_parser: Optional[str],
metadata_server_config_file: Optional[str], server_role: Optional[str],
fail_fast_on_attention_window_too_large: bool,
enable_chunked_prefill: bool):
enable_chunked_prefill: bool, disagg_cluster_uri: Optional[str]):
"""Running an OpenAI API compatible server
MODEL: model name | HF checkpoint path | TensorRT engine path
@ -364,14 +372,27 @@ def serve(
metadata_server_cfg = parse_metadata_server_config_file(
metadata_server_config_file)
if metadata_server_cfg is not None:
assert server_role is not None, "server_role is required when metadata_server_cfg is provided"
# disagg_cluster_config can come from the config file or from the
# "--disagg_cluster_uri" command-line option; the command-line URI takes
# precedence over the cluster_uri in the config file.
disagg_cluster_config = llm_args.pop("disagg_cluster", None)
if disagg_cluster_config:
disagg_cluster_config = extract_disagg_cluster_config(
disagg_cluster_config, disagg_cluster_uri)
elif disagg_cluster_uri:
disagg_cluster_config = DisaggClusterConfig(
cluster_uri=disagg_cluster_uri)
if metadata_server_cfg is not None or disagg_cluster_config is not None:
assert (
server_role is not None
), "server_role is required when metadata_server_cfg or disagg_cluster_config is provided"
try:
server_role = ServerRole[server_role.upper()]
except KeyError:
raise ValueError(f"Invalid server role: {server_role}. " \
f"Must be one of: {', '.join([role.name for role in ServerRole])}")
launch_server(host, port, llm_args, metadata_server_cfg, server_role)
launch_server(host, port, llm_args, metadata_server_cfg, server_role,
disagg_cluster_config)
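For reference, a hypothetical context-worker launch wired to the new option; the subprocess style mirrors the tests added below, and the model name, ports, and URI are illustrative only:

import subprocess

# Launch a context worker that registers itself with the disagg cluster at
# the given URI (hypothetical values; --disagg_cluster_uri is the new flag).
subprocess.Popen([
    "trtllm-serve", "serve", "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "--host", "localhost", "--port", "18100",
    "--server_role", "context",
    "--disagg_cluster_uri", "http://localhost:18000",
])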
@click.command("mm_embedding_serve")

View File

@ -1,7 +1,7 @@
import logging
from dataclasses import dataclass, field
from enum import IntEnum
from typing import Any, List, Literal, Optional, Tuple
from typing import Any, Dict, List, Literal, Optional, Tuple
import yaml
from mpi4py.MPI import COMM_WORLD, Comm
@ -68,6 +68,7 @@ class DisaggServerConfig():
conditional_disagg_config: Optional[ConditionalDisaggConfig] = None
max_retries: int = 1
perf_metrics_max_requests: int = 0
disagg_cluster_config: Optional[DisaggClusterConfig] = None
@dataclass
@ -111,6 +112,7 @@ def extract_disagg_cfg(hostname: str = 'localhost',
context_servers: Optional[dict] = None,
generation_servers: Optional[dict] = None,
conditional_disagg_config: Optional[dict] = None,
disagg_cluster: Optional[dict] = None,
**kwargs: Any) -> DisaggServerConfig:
context_servers = context_servers or {}
generation_servers = generation_servers or {}
@ -131,15 +133,18 @@ def extract_disagg_cfg(hostname: str = 'localhost',
# Inherit the value from the top-level
servers[key] = value
server_configs = []
disagg_cluster_config = None
ctx_router_config = extract_router_config(context_servers)
gen_router_config = extract_router_config(generation_servers)
server_configs = extract_ctx_gen_cfgs(
type="ctx", **context_servers) + extract_ctx_gen_cfgs(
type="gen", **generation_servers)
ctx_router_config.server_role = ServerRole.CONTEXT
gen_router_config.server_role = ServerRole.GENERATION
if disagg_cluster:
disagg_cluster_config = extract_disagg_cluster_config(disagg_cluster)
else:
server_configs = extract_ctx_gen_cfgs(
type="ctx", **context_servers) + extract_ctx_gen_cfgs(
type="gen", **generation_servers)
conditional_disagg_config = ConditionalDisaggConfig(
**conditional_disagg_config) if conditional_disagg_config else None
@ -147,7 +152,8 @@ def extract_disagg_cfg(hostname: str = 'localhost',
config = DisaggServerConfig(server_configs, hostname, port,
ctx_router_config, gen_router_config,
conditional_disagg_config, max_retries,
perf_metrics_max_requests)
perf_metrics_max_requests,
disagg_cluster_config)
return config
@ -235,6 +241,33 @@ def get_server_configs_dict(
return num_workers, server_dict
def extract_disagg_cluster_config(
cluster_config_dict: Dict[str, Any],
cluster_uri: Optional[str] = None) -> DisaggClusterConfig:
"""
Build a DisaggClusterConfig from cluster_config_dict.
Fields that are missing fall back to the defaults of DisaggClusterConfig and MinimalInstances.
If cluster_uri is provided, it overrides the cluster_uri in cluster_config_dict.
"""
def update_dataclass(obj, data_dict: Dict[str, Any]):
for key, value in data_dict.items():
if key not in obj.__dataclass_fields__:
raise KeyError(
f"Key {key} not found in {obj.__class__.__name__}")
if value is not None:
setattr(obj, key, value)
return obj
cluster_config_dict["minimal_instances"] = update_dataclass(
MinimalInstances(), cluster_config_dict.get("minimal_instances", {}))
cluster_config = update_dataclass(
DisaggClusterConfig(cluster_uri or cluster_config_dict["cluster_uri"]),
# drop cluster_uri here so an explicitly passed cluster_uri is not
# overwritten by the value from the config dict
{k: v for k, v in cluster_config_dict.items() if k != "cluster_uri"},
)
return cluster_config
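A minimal usage sketch; the field names come from the test fixtures in this commit and the URIs are illustrative:

# Build a config from a dict (e.g. the "disagg_cluster" section of a YAML
# file), letting an explicitly passed URI override the one in the dict.
config = extract_disagg_cluster_config(
    {
        "cluster_uri": "http://cfg-host:18000",
        "cluster_name": "test_cluster",
        "minimal_instances": {"context_servers": 2, "generation_servers": 2},
    },
    cluster_uri="http://cli-host:18000",  # e.g. from --disagg_cluster_uri
)
# config.cluster_uri == "http://cli-host:18000"; unspecified fields keep
# their dataclass defaults.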
def split_world_comm(
server_configs: List[CtxGenServerConfig]) -> Tuple[bool, int, Comm]:

View File

@ -281,9 +281,10 @@ class HttpClusterStorageServer(ClusterStorage):
self._storage.pop(k)
for k, v in kv_to_delete.items():
await self._notify_watch_event(k, v, WatchEventType.DELETE)
logger.debug(
f"Checked expired, {before_len} -> {len(self._storage)}, keys to delete: {kv_to_delete.keys()}"
)
if len(kv_to_delete) > 0:
logger.debug(
f"Checked expired, {before_len} -> {len(self._storage)}, keys to delete: {kv_to_delete.keys()}"
)
except Exception as e:
logger.error(f"Error checking expired: {e}")
@ -298,9 +299,12 @@ class HttpClusterStorageClient(ClusterStorage):
self._cluster_name = cluster_name
def __del__(self):
if asyncio.get_event_loop():
asyncio.run_coroutine_threadsafe(self._session.close(),
asyncio.get_event_loop())
try:
if asyncio.get_event_loop():
asyncio.run_coroutine_threadsafe(self._session.close(),
asyncio.get_event_loop())
except RuntimeError:
pass
def _url_for(self, endpoint: str) -> str:
return f"{self._cluster_uri}/{endpoint}"

View File

@ -46,9 +46,13 @@ class DisaggClusterManager:
self._watch_handle = None
def __del__(self):
if asyncio.get_event_loop():
asyncio.run_coroutine_threadsafe(self.stop(),
asyncio.get_event_loop())
try:
if asyncio.get_event_loop():
asyncio.run_coroutine_threadsafe(self.stop(),
asyncio.get_event_loop())
except RuntimeError:
# the event loop may not be running when the cluster manager is destroyed
pass
async def start(self) -> None:
await self._cluster_storage.start()
@ -208,9 +212,13 @@ class DisaggClusterWorker:
self._worker_id = f"{role.name}-{host}:{port}-{int(time.time()*1000)}-{os.getpid()}-{random.randint(0, 1000):03}"
def __del__(self):
if asyncio.get_event_loop():
asyncio.run_coroutine_threadsafe(self.deregister_worker(),
asyncio.get_event_loop())
try:
if asyncio.get_event_loop():
asyncio.run_coroutine_threadsafe(self.deregister_worker(),
asyncio.get_event_loop())
except RuntimeError:
# the event loop may not be running when the worker is destroyed
pass
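The same guarded pattern protects __del__ in HttpClusterStorageClient, DisaggClusterManager, and DisaggClusterWorker: asyncio.get_event_loop() raises RuntimeError when no event loop is available (typically during interpreter shutdown), so the cleanup must be best-effort. A standalone sketch of the idiom:

import asyncio

class AsyncResource:
    async def aclose(self):
        ...  # release async resources here

    def __del__(self):
        # Best-effort scheduling of async cleanup from a finalizer; if the
        # event loop is already gone, get_event_loop() raises RuntimeError
        # and there is nothing left to clean up.
        try:
            if asyncio.get_event_loop():
                asyncio.run_coroutine_threadsafe(self.aclose(),
                                                 asyncio.get_event_loop())
        except RuntimeError:
            pass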
@property
def worker_id(self) -> str:

View File

@ -20,9 +20,13 @@ from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR
# yapf: disable
from tensorrt_llm.executor import CppExecutorError
from tensorrt_llm.llmapi.disagg_utils import (DisaggServerConfig,
MetadataServerConfig,
MetadataServerConfig, ServerRole,
get_ctx_gen_server_urls)
from tensorrt_llm.logger import logger
from tensorrt_llm.serve.cluster_storage import (WatchEventType,
create_cluster_storage)
from tensorrt_llm.serve.disagg_auto_scaling import (DisaggClusterManager,
WorkerInfo)
from tensorrt_llm.serve.metadata_server import create_metadata_server
from tensorrt_llm.serve.openai_protocol import (ChatCompletionRequest,
ChatCompletionResponse,
@ -53,7 +57,6 @@ class OpenAIDisaggServer:
self.gen_router = create_router(
config.gen_router_config, self.gen_servers, metadata_server_cfg, self.metadata_server)
self.conditional_disagg_config = config.conditional_disagg_config
self.perf_metrics_max_requests = config.perf_metrics_max_requests
if self.perf_metrics_max_requests > 0:
# record corresponding keys of context and generation servers for perf metrics
@ -82,18 +85,27 @@ class OpenAIDisaggServer:
self._metrics_task = None
self.metrics_interval_secs = metrics_interval_secs
self.disagg_cluster_config = config.disagg_cluster_config
self.disagg_cluster_storage = None
self.disagg_cluster_manager = None
self._update_worker_task = None
logger.info(f"Server max retries: {self.max_retries}")
if (len(self.gen_servers) == 0):
raise ValueError("At least one generation server must be provided")
if self.disagg_cluster_config is None:
if (len(self.gen_servers) == 0):
raise ValueError("At least one generation server must be provided")
if os.getenv("TRTLLM_DISAGG_BENCHMARK_GEN_ONLY") != "1" and len(self.ctx_servers) == 0:
raise ValueError("At least one context server must be provided")
if os.getenv("TRTLLM_DISAGG_BENCHMARK_GEN_ONLY") != "1" and len(self.ctx_servers) == 0:
raise ValueError("At least one context server must be provided")
if self.conditional_disagg_config is not None and \
not isinstance(self.gen_router, KvCacheAwareRouter):
raise ValueError("Generation router must be a KvCacheAwareRouter to enable conditional disaggregation")
if self.disagg_cluster_config and self.metadata_server:
raise ValueError("Cluster manager and metadata server cannot be used together")
# Session will be initialized in lifespan
self.session: Optional[aiohttp.ClientSession] = None
@ -104,6 +116,11 @@ class OpenAIDisaggServer:
connector=aiohttp.TCPConnector(limit=0, limit_per_host=0, force_close=True),
timeout=aiohttp.ClientTimeout(total=req_timeout_secs))
if self.disagg_cluster_manager:
await self.disagg_cluster_manager.start()
await self.disagg_cluster_manager.watch_workers()
self._update_worker_task = asyncio.create_task(self._update_router_by_watch_events())
logger.info("Waiting for context and generation servers to be ready")
await self.wait_for_servers_ready(server_start_timeout_secs)
@ -135,6 +152,9 @@ class OpenAIDisaggServer:
pass
await self.session.close() # Ensure session cleanup
if self.disagg_cluster_manager:
self._update_worker_task.cancel()
await self.disagg_cluster_manager.stop()
self.app = FastAPI(lifespan=lifespan)
@ -145,6 +165,10 @@ class OpenAIDisaggServer:
return JSONResponse(status_code=400, content={"error": str(exc)})
self.register_routes()
if self.disagg_cluster_config:
self.disagg_cluster_storage = create_cluster_storage(self.disagg_cluster_config.cluster_uri, self.disagg_cluster_config.cluster_name, server=self.app)
self.disagg_cluster_manager = DisaggClusterManager(self.disagg_cluster_config, self.disagg_cluster_storage)
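Both storage factories appear in this commit: the disagg server embeds the storage in its own FastAPI app, while each worker connects to it as a client (see the OpenAIServer changes below). A short pairing sketch with illustrative values:

# Proxy side: mount the cluster storage on the FastAPI app.
storage = create_cluster_storage("http://localhost:18000", "test_cluster",
                                 server=app)
# Worker side: talk to that storage over HTTP.
client = create_cluster_storage_client("http://localhost:18000",
                                       "test_cluster")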
async def _increment_metric(self, key: str, amount: int = 1):
if self.metrics_interval_secs > 0:
@ -185,14 +209,24 @@ class OpenAIDisaggServer:
self.app.add_api_route("/v1/chat/completions",
self.openai_chat_completion,
methods=["POST"])
self.app.add_api_route("/cluster_info", self.cluster_info, methods=["GET"])
async def health(self) -> Response:
if not await self.is_ready():
return Response(status_code=500)
return Response(status_code=200)
async def version(self) -> JSONResponse:
ver = {"version": VERSION}
return JSONResponse(content=ver)
async def cluster_info(self) -> JSONResponse:
if self.disagg_cluster_manager:
cluster_info = await self.disagg_cluster_manager.cluster_info()
cluster_info["is_ready"] = await self.is_ready()
return JSONResponse(content=cluster_info)
return JSONResponse(content={})
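A hedged probe of the new endpoint; the response keys match verify_cluster_info in the tests added below, and the host/port are illustrative:

import requests

info = requests.get("http://localhost:18000/cluster_info").json()
if info.get("is_ready"):
    print(len(info["current_workers"]["context_servers"]),
          len(info["current_workers"]["generation_servers"]))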
async def _add_perf_metrics_keys(self, ctx_server: str, gen_server: str, ctx_request_id: int, raw_request: Request):
async with self.perf_metrics_keys_lock:
self.perf_metrics_keys.append((ctx_server, gen_server, ctx_request_id, raw_request.state.server_arrival_time, raw_request.state.server_first_token_time))
@ -252,6 +286,8 @@ class OpenAIDisaggServer:
async def openai_completion(self, req: CompletionRequest, raw_request: Request) -> Response:
if not await self.is_ready():
raise HTTPException(status_code=400, detail="Cluster is not ready")
try:
if not isinstance(req.prompt, str):
# Check if it's a list and contains integers
@ -266,7 +302,8 @@ class OpenAIDisaggServer:
await self._handle_exception(e)
async def openai_chat_completion(self, req: ChatCompletionRequest, raw_request: Request) -> Response:
if not await self.is_ready():
raise HTTPException(status_code=400, detail="Cluster is not ready")
try:
return await self._send_disagg_request(req, raw_request)
except Exception as e:
@ -574,5 +611,31 @@ class OpenAIDisaggServer:
raise TimeoutError("Timeout waiting for context and generation servers to be ready")
logger.info("Context and generation servers are ready")
async def is_ready(self) -> bool:
if self.disagg_cluster_manager:
return await self.disagg_cluster_manager.is_ready_with_router(len(self.ctx_router.servers), len(self.gen_router.servers))
return True
async def wait_for_servers_ready(self, server_start_timeout_secs: int = 180):
await self.wait_for_all_servers_ready(self.session, self.ctx_servers, self.gen_servers, server_start_timeout_secs)
async def _update_router_by_watch_events(self):
def worker_repr(worker_info: WorkerInfo):
return f"http://{worker_info.host}:{worker_info.port}"
router_map = {
ServerRole.CONTEXT: self.ctx_router,
ServerRole.GENERATION: self.gen_router
}
logger.info("Start updating routers by worker events")
while True:
try:
worker_events = await self.disagg_cluster_manager.get_worker_events()
for worker_info, event_type in worker_events:
if event_type == WatchEventType.SET:
await router_map[worker_info.role].add_server(worker_repr(worker_info))
elif event_type == WatchEventType.DELETE:
await router_map[worker_info.role].remove_server(worker_repr(worker_info))
logger.info(f"Worker {event_type.name} event: {worker_info.worker_id}")
except Exception as e:
logger.error(f"Error updating routers by worker events: {e}")
await asyncio.sleep(1)

View File

@ -28,12 +28,15 @@ from tensorrt_llm.inputs.data import TokensPrompt
from tensorrt_llm.inputs.utils import ConversationMessage, apply_chat_template
from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams
from tensorrt_llm.llmapi import MultimodalEncoder
from tensorrt_llm.llmapi.disagg_utils import MetadataServerConfig, ServerRole
from tensorrt_llm.llmapi.disagg_utils import (DisaggClusterConfig,
MetadataServerConfig, ServerRole)
from tensorrt_llm.llmapi.llm import RequestOutput
from tensorrt_llm.logger import logger
from tensorrt_llm.metrics.collector import MetricsCollector
from tensorrt_llm.serve.chat_utils import (check_multiple_response,
parse_chat_messages_coroutines)
from tensorrt_llm.serve.cluster_storage import create_cluster_storage_client
from tensorrt_llm.serve.disagg_auto_scaling import DisaggClusterWorker
from tensorrt_llm.serve.metadata_server import create_metadata_server
from tensorrt_llm.serve.openai_protocol import (ChatCompletionRequest,
ChatCompletionResponse,
@ -75,12 +78,17 @@ class OpenAIServer:
llm: Union[LLM, MultimodalEncoder],
model: str,
server_role: Optional[ServerRole],
metadata_server_cfg: MetadataServerConfig):
metadata_server_cfg: MetadataServerConfig,
disagg_cluster_config: Optional[DisaggClusterConfig] = None):
self.llm = llm
self.tokenizer = llm.tokenizer
self.metadata_server = create_metadata_server(metadata_server_cfg)
self.disagg_cluster_config = disagg_cluster_config
self.server_role = server_role
self.binding_addr = None # Will be set in __call__
# Will be set in __call__
self.binding_addr = None
self.host = None
self.port = None
hf_tokenizer_path = llm._hf_model_dir or self.tokenizer.tokenizer.name_or_path
trust_remote_code = llm.args.trust_remote_code
try:
@ -129,6 +137,10 @@ class OpenAIServer:
else:
self.use_harmony = (self.model_config.model_type == "gpt_oss")
# state used when this server runs as a disagg worker
self.disagg_cluster_storage = None
self.disagg_cluster_worker = None
@asynccontextmanager
async def lifespan(app: FastAPI):
if self.metadata_server is not None:
@ -144,12 +156,19 @@ class OpenAIServer:
self.metadata_server.put(f"trtllm/{self.llm.llm_id}", metadata)
logger.info(f"trtllm/{self.llm.llm_id} is registered")
if self.disagg_cluster_config:
self.disagg_cluster_storage = create_cluster_storage_client(self.disagg_cluster_config.cluster_uri, self.disagg_cluster_config.cluster_name)
self.disagg_cluster_worker = DisaggClusterWorker(self.server_role, self.host, self.port, self.disagg_cluster_config, self.disagg_cluster_storage)
await self.disagg_cluster_worker.register_worker()
# terminate rank0 worker
yield
if self.metadata_server is not None:
self.metadata_server.remove(f"trtllm/{self.llm.llm_id}")
logger.info(f"trtllm/{self.llm.llm_id} is unregistered")
if self.disagg_cluster_worker:
await self.disagg_cluster_worker.deregister_worker()
self.llm.shutdown()
self.app = FastAPI(lifespan=lifespan)
@ -949,6 +968,8 @@ class OpenAIServer:
async def __call__(self, host, port):
# Store the binding address for server registration
self.binding_addr = f"http://{host}:{port}"
self.host = host
self.port = port
config = uvicorn.Config(self.app,
host=host,
port=port,

View File

@ -164,6 +164,34 @@ class Router(ABC):
new_servers: The new server list
"""
@property
def servers(self) -> List[str]:
return self._servers
async def add_server(self, server: str):
if server in self._servers:
logger.warning(f"Server {server} already exists")
return
async with self._lock:
old_servers = self._servers.copy()
self._servers = [*old_servers, server]
self._on_servers_updated(old_servers, self._servers)
logger.debug(
f"Added server {server}, current server list: {self._servers}")
async def remove_server(self, server: str):
if server not in self._servers:
logger.warning(f"Server {server} does not exist")
return
async with self._lock:
old_servers = self._servers.copy()
self._servers = [
old_server for old_server in old_servers if old_server != server
]
self._on_servers_updated(old_servers, self._servers)
logger.debug(
f"Removed server {server}, current server list: {self._servers}")
@abstractmethod
async def get_next_server(self, request: OpenAIRequest) -> tuple[str, dict]:
'''Select server by request and return some intermediate information'''
@ -523,6 +551,7 @@ class KvCacheAwareRouter(Router):
super().__init__(server_role, servers, metadata_server_cfg,
metadata_server)
self._lock = asyncio.Lock()
self._use_tokens = use_tokens
# Load map between servers and their number of tokens processed
self._server_state: dict[str, KvCacheAwareServerState] = {
@ -557,7 +586,8 @@ class KvCacheAwareRouter(Router):
return [tokenizer(prompt)["input_ids"] for prompt in prompts]
async def get_next_server(self, request: OpenAIRequest) -> tuple[str, dict]:
servers = list(self._server_state.keys())
async with self._lock:
servers = list(self._server_state.keys())
token_lists = self._tokenize(request)
block_hashes: list[list[int]] = []
for token_list in token_lists:
@ -589,8 +619,8 @@ class KvCacheAwareRouter(Router):
i] / self._max_batch_size
scores.append(score)
server = servers[scores.index(max(scores))]
await self._server_state[server].increment_load(request)
async with self._lock:
await self._server_state[server].increment_load(request)
self._req_routing_table[id(request)] = server
return server, {
"block_hashes": block_hashes, # list[list[int]]
@ -604,8 +634,9 @@ class KvCacheAwareRouter(Router):
async with self._lock:
server = self._req_routing_table[id(request)]
del self._req_routing_table[id(request)]
await self._server_state[server].decrement_load(request,
session=session)
if server in self._server_state:
await self._server_state[server].decrement_load(request,
session=session)
def _on_servers_updated(self, old_servers, new_servers):
raise NotImplementedError(
@ -632,22 +663,19 @@ def create_router(router_config: Optional[RouterConfig],
Raises:
ValueError: If an unsupported router type is provided
"""
if router_config is None:
# Create a default router without server_role
return RoundRobinRouter(None, servers)
router_map = {
"round_robin": RoundRobinRouter,
"load_balancing": LoadBalancingRouter,
"kv_cache_aware": KvCacheAwareRouter,
}
router_type = router_config.type
router_type = router_config.type if router_config else "round_robin"
router_class = router_map.get(router_type.lower())
if router_class is None:
raise ValueError(f"Unsupported router type: {router_type}. "
f"Supported types are: {list(router_map.keys())}")
extra_args = router_config.args if router_config else {}
# Pass server_role as the first argument
return router_class(router_config.server_role, servers, metadata_server_cfg,
metadata_server, **router_config.args)
return router_class(router_config.server_role if router_config else None,
servers, metadata_server_cfg, metadata_server,
**extra_args)
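With this refactor, a missing RouterConfig resolves through the same table as an explicit round_robin entry, e.g. (assuming the metadata arguments default to None):

router = create_router(None, ["http://localhost:8001"])  # -> RoundRobinRouter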

View File

@ -0,0 +1,397 @@
import asyncio
import os
import subprocess
import tempfile
import openai
import pytest
import requests
import yaml
from tensorrt_llm.logger import logger
TEST_PORT = 18000
HEARTBEAT_INTERVAL = 1
INACTIVE_TIMEOUT = 2
ROUTER_TYPES = ["round_robin",
"load_balancing"] # kv_cache_aware doesn't support auto-scaling
@pytest.fixture
def model_name():
return "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
@pytest.fixture
def disagg_cluster_config():
# same cluster config for workers and proxy server
return {
"cluster_uri": f"http://localhost:{TEST_PORT}",
"cluster_name": "test_cluster",
"heartbeat_interval_sec": HEARTBEAT_INTERVAL,
"inactive_timeout_sec": INACTIVE_TIMEOUT,
}
@pytest.fixture
def router(request):
return request.param
@pytest.fixture
def disagg_server_config(disagg_cluster_config, router):
return {
"hostname": "localhost",
"port": TEST_PORT,
"disagg_cluster": disagg_cluster_config,
"context_servers": {
"router": {
"type": router
}
},
"generation_servers": {
"router": {
"type": router
}
},
}
@pytest.fixture
def worker_config(disagg_cluster_config):
return {
"disagg_cluster": disagg_cluster_config,
"disable_overlap_scheduler": True,
"cache_transceiver_config": {
"backend": "DEFAULT"
},
"kv_cache_config": {
"free_gpu_memory_fraction": 0.2,
"enable_partial_reuse": False,
},
"cuda_graph_config": {},
}
def _run_worker(model_name, worker_config, role, port=8000, device=-1):
worker_config_path = tempfile.NamedTemporaryFile(delete=False)
with open(worker_config_path.name, "w+") as f:
yaml.dump(worker_config, f)
f.flush()
cmd = [
"trtllm-serve",
"serve",
model_name,
"--host",
"localhost",
"--port",
str(port),
"--extra_llm_api_options",
worker_config_path.name,
"--server_role",
"context" if role.startswith("ctx") else "generation",
]
env = os.environ.copy()
if device != -1:
env["CUDA_VISIBLE_DEVICES"] = str(device)
return subprocess.Popen(cmd, env=env)
def run_ctx_worker(model_name,
ctx_worker_config,
port=TEST_PORT + 100,
device=0):
return _run_worker(model_name, ctx_worker_config, "ctx", port, device)
def run_gen_worker(model_name,
gen_worker_config,
port=TEST_PORT + 200,
device=1):
return _run_worker(model_name, gen_worker_config, "gen", port, device)
def run_disagg_server(disagg_server_config, port=TEST_PORT):
# the argument is the full disagg server config (which embeds the cluster config)
disagg_server_config_path = f"/tmp/disagg_server_{port}_config.yaml"
disagg_server_config["port"] = port
with open(disagg_server_config_path, "w+") as f:
yaml.dump(disagg_server_config, f)
cmds = ["trtllm-serve", "disaggregated", "-c", disagg_server_config_path]
f = open("disagg_server.log", "w+")
p = subprocess.Popen(cmds, stdout=f, stderr=f)
return p
async def wait_for_disagg_server_ready(port):
while True:
await asyncio.sleep(3)
logger.info(f"Waiting for disagg server to be ready")
try:
info_resp = requests.get(f"http://localhost:{port}/cluster_info")
if info_resp.status_code == 200:
info = info_resp.json()
if info["is_ready"]:
break
logger.info(
f"Waiting for disagg server to be ready: {info_resp.json()}"
)
else:
logger.info(
f"Failed to get cluster info: {info_resp.status_code}")
await asyncio.sleep(3)
except requests.exceptions.RequestException as e:
logger.error(f"Failed to get cluster info: {e}")
async def wait_for_worker_ready(port):
while True:
await asyncio.sleep(3)
logger.info(f"Waiting for worker {port} to be ready")
try:
info_resp = requests.get(f"http://localhost:{port}/health")
if info_resp.status_code == 200:
break
except requests.exceptions.RequestException as e:
logger.info(f"Failed to get worker info: {e}")
def verify_cluster_info(ready,
ctx_workers=-1,
gen_workers=-1,
port=TEST_PORT,
expected_code=200):
info_resp = requests.get(f"http://localhost:{port}/cluster_info")
assert info_resp.status_code == expected_code
info = info_resp.json()
print("verify_cluster_info", info, ready, ctx_workers, gen_workers)
assert info["is_ready"] == ready
if ctx_workers != -1:
assert len(info["current_workers"]["context_servers"]) == ctx_workers
if gen_workers != -1:
assert len(info["current_workers"]["generation_servers"]) == gen_workers
def terminate(*args):
# terminate each process independently so one failure doesn't skip the rest
for arg in args:
if arg and isinstance(arg, subprocess.Popen):
try:
arg.terminate()
arg.wait(timeout=10)
except Exception:
pass
def request_completion(model_name, prompt, port=TEST_PORT):
client = openai.OpenAI(api_key="tensorrt_llm",
base_url=f"http://localhost:{port}/v1")
return client.completions.create(model=model_name,
prompt=prompt,
max_tokens=10,
temperature=0.0)
@pytest.mark.skip_less_device(2)
@pytest.mark.parametrize("router", ROUTER_TYPES, indirect=True)
@pytest.mark.asyncio(loop_scope="module")
@pytest.mark.timeout(600)
async def test_service_discovery(model_name, disagg_server_config,
worker_config, router):
ctx_worker1 = None
gen_worker1 = None
disagg_server = None
try:
# initial cluster, 1 ctx, 1 gen, request should succeed
ctx_worker1 = run_ctx_worker(model_name, worker_config, TEST_PORT + 100)
gen_worker1 = run_gen_worker(model_name, worker_config, TEST_PORT + 200)
disagg_server = run_disagg_server(disagg_server_config, TEST_PORT)
await wait_for_disagg_server_ready(TEST_PORT)
verify_cluster_info(True, 1, 1)
response = request_completion(model_name,
"Hello, my name is",
port=TEST_PORT)
print(response)
finally:
terminate(ctx_worker1, gen_worker1, disagg_server)
@pytest.mark.skip_less_device(2)
@pytest.mark.parametrize(
"router", ["round_robin"], indirect=True
) # use only round_robin to reduce test time; the router type doesn't matter for this test
@pytest.mark.asyncio(loop_scope="module")
@pytest.mark.timeout(600)
async def test_minimal_instances(model_name, disagg_server_config,
worker_config, router):
# the cluster should have at least 2 ctx and 2 gen workers
minimal_instances = {
"context_servers": 2,
"generation_servers": 2,
}
disagg_server_config["disagg_cluster"][
"minimal_instances"] = minimal_instances
worker_config["disagg_cluster"]["minimal_instances"] = minimal_instances
processes = []
try:
processes.append(
run_ctx_worker(model_name, worker_config, TEST_PORT + 100))
processes.append(
run_gen_worker(model_name, worker_config, TEST_PORT + 200))
processes.append(run_disagg_server(disagg_server_config, TEST_PORT))
await wait_for_worker_ready(TEST_PORT + 100)
await wait_for_worker_ready(TEST_PORT + 200)
verify_cluster_info(False, 1, 1)
# with only 1 ctx and 1 gen worker, the request should fail
with pytest.raises(Exception):
response = request_completion(model_name,
"Hello, my name is",
port=TEST_PORT)
print(response)
processes.append(
run_ctx_worker(model_name, worker_config, TEST_PORT + 101))
processes.append(
run_gen_worker(model_name, worker_config, TEST_PORT + 201))
await wait_for_disagg_server_ready(TEST_PORT)
verify_cluster_info(True, 2, 2)
response = request_completion(model_name,
"Hello, my name is",
port=TEST_PORT)
print(response)
finally:
terminate(*processes)
@pytest.mark.skip_less_device(2)
@pytest.mark.parametrize("router", ROUTER_TYPES, indirect=True)
@pytest.mark.asyncio(loop_scope="module")
@pytest.mark.timeout(600)
async def test_worker_restart(model_name, disagg_server_config, worker_config,
router):
ctx_worker1 = None
ctx_worker2 = None
gen_worker1 = None
gen_worker2 = None
disagg_server = None
try:
# initial cluster, 1 ctx, 1 gen, request should succeed
ctx_worker1 = run_ctx_worker(model_name,
worker_config,
TEST_PORT + 100,
device=0)
gen_worker1 = run_gen_worker(model_name,
worker_config,
TEST_PORT + 200,
device=1)
disagg_server = run_disagg_server(disagg_server_config, TEST_PORT)
await wait_for_disagg_server_ready(TEST_PORT)
verify_cluster_info(True, 1, 1)
response = request_completion(model_name,
"Hello, my name is",
port=TEST_PORT)
print(response)
# kill gen1, the request should fail
terminate(gen_worker1)
await asyncio.sleep(INACTIVE_TIMEOUT)
verify_cluster_info(False, 1, 0)
with pytest.raises(Exception):
request_completion(model_name, "Hello, my name is", port=TEST_PORT)
test_prompt = "The capital of France is"
# add gen2, the request should succeed
gen_worker2 = run_gen_worker(model_name,
worker_config,
TEST_PORT + 201,
device=2)
await wait_for_worker_ready(TEST_PORT + 201)
await asyncio.sleep(INACTIVE_TIMEOUT)
verify_cluster_info(True, 1, 1)
response = request_completion(model_name, test_prompt, port=TEST_PORT)
print(response)
response_text = response.choices[0].text
assert len(response.choices[0].text) >= 1
# kill ctx1, the request should fail
terminate(ctx_worker1)
await asyncio.sleep(INACTIVE_TIMEOUT)
verify_cluster_info(False, 0, 1)
with pytest.raises(Exception):
request_completion(model_name, test_prompt, port=TEST_PORT)
# add ctx2, the request should succeed
ctx_worker2 = run_ctx_worker(model_name,
worker_config,
TEST_PORT + 101,
device=3)
await wait_for_worker_ready(TEST_PORT + 101)
verify_cluster_info(True, 1, 1)
response = request_completion(model_name, test_prompt, port=TEST_PORT)
response_text = response.choices[0].text
assert len(response.choices[0].text) >= 1
# restart ctx1 and gen1 with the same ports, we have 2 ctxs and 2 gens now
ctx_worker1 = run_ctx_worker(model_name, worker_config, TEST_PORT + 100)
gen_worker1 = run_gen_worker(model_name, worker_config, TEST_PORT + 200)
await wait_for_worker_ready(TEST_PORT + 100)
await wait_for_worker_ready(TEST_PORT + 200)
await asyncio.sleep(INACTIVE_TIMEOUT)
verify_cluster_info(True, 2, 2)
# send 10 requests; the responses should be served by different ctx/gen workers (though we cannot verify that here)
for _ in range(10):
response = request_completion(model_name,
test_prompt,
port=TEST_PORT)
assert response.choices[0].text == response_text
print(response)
finally:
terminate(ctx_worker1, ctx_worker2, gen_worker1, gen_worker2,
disagg_server)
@pytest.mark.skip_less_device(2)
@pytest.mark.parametrize("router", ["round_robin"], indirect=True)
@pytest.mark.asyncio(loop_scope="module")
@pytest.mark.timeout(300)
async def test_disagg_server_restart(model_name, disagg_server_config,
worker_config, router):
ctx_worker1 = None
gen_worker1 = None
disagg_server = None
try:
# initial cluster, 1 ctx, 1 gen, request should succeed
ctx_worker1 = run_ctx_worker(model_name, worker_config, TEST_PORT + 100)
gen_worker1 = run_gen_worker(model_name, worker_config, TEST_PORT + 200)
disagg_server = run_disagg_server(disagg_server_config, TEST_PORT)
await wait_for_disagg_server_ready(TEST_PORT)
verify_cluster_info(True, 1, 1)
response = request_completion(model_name,
"Hello, my name is",
port=TEST_PORT)
print(response)
response_text = response.choices[0].text
# kill disagg server, the request should fail
terminate(disagg_server)
await asyncio.sleep(INACTIVE_TIMEOUT)
with pytest.raises(Exception):
verify_cluster_info(False, 1, 1, expected_code=500)
# restart disagg server, the request should succeed
disagg_server = run_disagg_server(disagg_server_config, TEST_PORT)
await wait_for_disagg_server_ready(TEST_PORT)
verify_cluster_info(True, 1, 1)
response = request_completion(model_name,
"Hello, my name is",
port=TEST_PORT)
print(response)
assert response.choices[0].text == response_text
finally:
terminate(disagg_server, ctx_worker1, gen_worker1)

View File

@ -91,6 +91,12 @@ l0_dgx_h100:
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=2-ctx_pp=2]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_multi_instance[GSM8K]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_multi_instance[MMLU]
- disaggregated/test_auto_scaling.py::test_service_discovery[round_robin]
- disaggregated/test_auto_scaling.py::test_service_discovery[load_balancing]
- disaggregated/test_auto_scaling.py::test_worker_restart[round_robin]
- disaggregated/test_auto_scaling.py::test_worker_restart[load_balancing]
- disaggregated/test_auto_scaling.py::test_minimal_instances[round_robin]
- disaggregated/test_auto_scaling.py::test_disagg_server_restart[round_robin]
- condition:
ranges:
system_gpu_count:

View File

@ -35,15 +35,50 @@ def get_yaml_config():
return config
@pytest.fixture
def sample_yaml_config():
config = get_yaml_config()
def get_yaml_config_with_disagg_cluster():
config = {
"hostname": "test_host",
"port": 9000,
"context_servers": {
"max_batch_size": 1,
"tensor_parallel_size": 2,
"pipeline_parallel_size": 1,
},
"generation_servers": {
"router": {
"type": "load_balancing",
"use_tokens": False,
},
"max_batch_size": 1,
"tensor_parallel_size": 1,
"pipeline_parallel_size": 1,
},
"disagg_cluster": {
"cluster_uri": "http://test_host:9000",
"cluster_name": "test_cluster",
"minimal_instances": {
"context_servers": 2,
"generation_servers": 2,
},
"heartbeat_interval_sec": 1,
"inactive_timeout_sec": 2,
},
}
return config
@pytest.fixture
def sample_yaml_file(tmp_path):
config = get_yaml_config()
def sample_yaml_config(request):
if request.param == "disagg_cluster":
config = get_yaml_config_with_disagg_cluster()
else:
config = get_yaml_config()
return config
@pytest.fixture
def sample_yaml_file(sample_yaml_config, tmp_path):
config = sample_yaml_config
yaml_file = tmp_path / "test_config.yaml"
with open(yaml_file, "w") as f:
@ -51,24 +86,30 @@ def sample_yaml_file(tmp_path):
return yaml_file
def verify_disagg_config(config: DisaggServerConfig):
def verify_disagg_config(config: DisaggServerConfig,
sample_yaml_config: dict | str = ""):
assert config.hostname == "test_host"
assert config.port == 9000
assert config.ctx_router_config.type == "round_robin"
assert config.gen_router_config.type == "load_balancing"
assert len(config.server_configs) == 3
if sample_yaml_config == "":
assert len(config.server_configs) == 3
def test_parse_disagg_config_file(sample_yaml_file):
@pytest.mark.parametrize("sample_yaml_config", ["disagg_cluster", ""],
indirect=True)
def test_parse_disagg_config_file(sample_yaml_file, sample_yaml_config):
config = parse_disagg_config_file(sample_yaml_file)
assert isinstance(config, DisaggServerConfig)
verify_disagg_config(config)
verify_disagg_config(config, sample_yaml_config)
@pytest.mark.parametrize("sample_yaml_config", ["disagg_cluster", ""],
indirect=True)
def test_extract_disagg_cfg(sample_yaml_config):
config = extract_disagg_cfg(**sample_yaml_config)
assert isinstance(config, DisaggServerConfig)
verify_disagg_config(config)
verify_disagg_config(config, sample_yaml_config)
def test_extract_ctx_gen_cfgs():
@ -85,6 +126,7 @@ def test_extract_ctx_gen_cfgs():
assert configs[0].instance_num_ranks == 2
@pytest.mark.parametrize("sample_yaml_config", [""], indirect=True)
def test_extract_router_config(sample_yaml_config):
ctx_server_config = sample_yaml_config["context_servers"]
gen_server_config = sample_yaml_config["generation_servers"]