mirror of https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00

[TRTLLM-7843][feat] implement disagg cluster auto-scaling (#8215)

Signed-off-by: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com>

This commit is contained in:
parent 9b54b3bfaf
commit 23d5280a90
@@ -67,7 +67,7 @@ nvtx
 matplotlib # FIXME: this is added to make nvtx happy
 meson
 ninja
-etcd3
+etcd3 @ git+https://github.com/kragniz/python-etcd3.git@e58a899579ba416449c4e225b61f039457c8072a
 blake3
 soundfile
 triton==3.3.1; platform_machine == "x86_64"
@@ -21,7 +21,9 @@ from tensorrt_llm.executor.utils import LlmLauncherEnvs
 from tensorrt_llm.llmapi import (BuildConfig, CapacitySchedulerPolicy,
                                  DynamicBatchConfig, KvCacheConfig,
                                  SchedulerConfig)
-from tensorrt_llm.llmapi.disagg_utils import (MetadataServerConfig, ServerRole,
+from tensorrt_llm.llmapi.disagg_utils import (DisaggClusterConfig,
+                                              MetadataServerConfig, ServerRole,
+                                              extract_disagg_cluster_config,
                                               parse_disagg_config_file,
                                               parse_metadata_server_config_file)
 from tensorrt_llm.llmapi.llm_utils import update_llm_args_with_extra_dict
@@ -140,7 +142,8 @@ def launch_server(host: str,
                   port: int,
                   llm_args: dict,
                   metadata_server_cfg: Optional[MetadataServerConfig] = None,
-                  server_role: Optional[ServerRole] = None):
+                  server_role: Optional[ServerRole] = None,
+                  disagg_cluster_config: Optional[DisaggClusterConfig] = None):

     backend = llm_args["backend"]
     model = llm_args["model"]
@@ -161,7 +164,8 @@ def launch_server(host: str,
     server = OpenAIServer(llm=llm,
                           model=model,
                           server_role=server_role,
-                          metadata_server_cfg=metadata_server_cfg)
+                          metadata_server_cfg=metadata_server_cfg,
+                          disagg_cluster_config=disagg_cluster_config)

     # Optionally disable GC (default: not disabled)
     if os.getenv("TRTLLM_SERVER_DISABLE_GC", "0") == "1":
@@ -313,6 +317,10 @@ class ChoiceWithAlias(click.Choice):
     help=
     "Exit with runtime error when attention window is too large to fit even a single sequence in the KV cache."
 )
+@click.option("--disagg_cluster_uri",
+              type=str,
+              default=None,
+              help="URI of the disaggregated cluster.")
 @click.option("--enable_chunked_prefill",
               is_flag=True,
               default=False,
@@ -327,7 +335,7 @@ def serve(
     extra_llm_api_options: Optional[str], reasoning_parser: Optional[str],
     metadata_server_config_file: Optional[str], server_role: Optional[str],
     fail_fast_on_attention_window_too_large: bool,
-    enable_chunked_prefill: bool):
+    enable_chunked_prefill: bool, disagg_cluster_uri: Optional[str]):
     """Running an OpenAI API compatible server

     MODEL: model name | HF checkpoint path | TensorRT engine path
@@ -364,14 +372,27 @@ def serve(
     metadata_server_cfg = parse_metadata_server_config_file(
         metadata_server_config_file)

-    if metadata_server_cfg is not None:
-        assert server_role is not None, "server_role is required when metadata_server_cfg is provided"
+    # Specify disagg_cluster_config in config file or through command line "--disagg_cluster_uri",
+    # but disagg_cluster_uri takes precedence over cluster uri in config file
+    disagg_cluster_config = llm_args.pop("disagg_cluster", None)
+    if disagg_cluster_config:
+        disagg_cluster_config = extract_disagg_cluster_config(
+            disagg_cluster_config, disagg_cluster_uri)
+    elif disagg_cluster_uri:
+        disagg_cluster_config = DisaggClusterConfig(
+            cluster_uri=disagg_cluster_uri)
+
+    if metadata_server_cfg is not None or disagg_cluster_config is not None:
+        assert (
+            server_role is not None
+        ), "server_role is required when metadata_server_cfg or disagg_cluster_config is provided"
         try:
             server_role = ServerRole[server_role.upper()]
         except ValueError:
             raise ValueError(f"Invalid server role: {server_role}. " \
                 f"Must be one of: {', '.join([role.name for role in ServerRole])}")
-    launch_server(host, port, llm_args, metadata_server_cfg, server_role)
+    launch_server(host, port, llm_args, metadata_server_cfg, server_role,
+                  disagg_cluster_config)


 @click.command("mm_embedding_serve")
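Net effect of the block above: the cluster config may come from the YAML passed via --extra_llm_api_options or from the new --disagg_cluster_uri flag, and the flag's URI wins when both are given. A standalone sketch of that resolution rule (the dataclass is a simplified stand-in, not the real DisaggClusterConfig):

    from dataclasses import dataclass
    from typing import Any, Dict, Optional

    @dataclass
    class ClusterConfig:  # simplified stand-in for DisaggClusterConfig
        cluster_uri: str

    def resolve(llm_args: Dict[str, Any], cli_uri: Optional[str]) -> Optional[ClusterConfig]:
        cfg = llm_args.pop("disagg_cluster", None)
        if cfg:
            # the CLI flag overrides the URI from the config file
            return ClusterConfig(cluster_uri=cli_uri or cfg["cluster_uri"])
        if cli_uri:
            return ClusterConfig(cluster_uri=cli_uri)
        return None

    assert resolve({"disagg_cluster": {"cluster_uri": "http://a:1"}},
                   "http://b:2").cluster_uri == "http://b:2"
    assert resolve({}, None) is None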
@@ -1,7 +1,7 @@
 import logging
 from dataclasses import dataclass, field
 from enum import IntEnum
-from typing import Any, List, Literal, Optional, Tuple
+from typing import Any, Dict, List, Literal, Optional, Tuple

 import yaml
 from mpi4py.MPI import COMM_WORLD, Comm
@@ -68,6 +68,7 @@ class DisaggServerConfig():
     conditional_disagg_config: Optional[ConditionalDisaggConfig] = None
     max_retries: int = 1
     perf_metrics_max_requests: int = 0
+    disagg_cluster_config: Optional[DisaggClusterConfig] = None


 @dataclass
@@ -111,6 +112,7 @@ def extract_disagg_cfg(hostname: str = 'localhost',
                        context_servers: Optional[dict] = None,
                        generation_servers: Optional[dict] = None,
                        conditional_disagg_config: Optional[dict] = None,
+                       disagg_cluster: Optional[dict] = None,
                        **kwargs: Any) -> DisaggServerConfig:
     context_servers = context_servers or {}
     generation_servers = generation_servers or {}
@@ -131,15 +133,18 @@ def extract_disagg_cfg(hostname: str = 'localhost',
             # Inherit the value from the top-level
             servers[key] = value

     server_configs = []
+    disagg_cluster_config = None
     ctx_router_config = extract_router_config(context_servers)
     gen_router_config = extract_router_config(generation_servers)

-    server_configs = extract_ctx_gen_cfgs(
-        type="ctx", **context_servers) + extract_ctx_gen_cfgs(
-            type="gen", **generation_servers)
-
     ctx_router_config.server_role = ServerRole.CONTEXT
     gen_router_config.server_role = ServerRole.GENERATION
+    if disagg_cluster:
+        disagg_cluster_config = extract_disagg_cluster_config(disagg_cluster)
+    else:
+        server_configs = extract_ctx_gen_cfgs(
+            type="ctx", **context_servers) + extract_ctx_gen_cfgs(
+                type="gen", **generation_servers)

     conditional_disagg_config = ConditionalDisaggConfig(
         **conditional_disagg_config) if conditional_disagg_config else None
@@ -147,7 +152,8 @@ def extract_disagg_cfg(hostname: str = 'localhost',
     config = DisaggServerConfig(server_configs, hostname, port,
                                 ctx_router_config, gen_router_config,
                                 conditional_disagg_config, max_retries,
-                                perf_metrics_max_requests)
+                                perf_metrics_max_requests,
+                                disagg_cluster_config)

     return config
@@ -235,6 +241,33 @@ def get_server_configs_dict(
     return num_workers, server_dict


+def extract_disagg_cluster_config(
+        cluster_config_dict: Dict[str, Any],
+        cluster_uri: Optional[str] = None) -> DisaggClusterConfig:
+    """
+    Build the DisaggClusterConfig from the cluster_config_dict.
+    Use the default value of DisaggClusterConfig and MinimalInstances if the corresponding fields are not provided.
+    If cluster_uri is provided, it will override the cluster_uri in the cluster_config_dict.
+    """
+
+    def update_dataclass(obj, data_dict: Dict[str, Any]):
+        for key, value in data_dict.items():
+            if key not in obj.__dataclass_fields__:
+                raise KeyError(
+                    f"Key {key} not found in {obj.__class__.__name__}")
+            if value is not None:
+                setattr(obj, key, value)
+        return obj
+
+    cluster_config_dict["minimal_instances"] = update_dataclass(
+        MinimalInstances(), cluster_config_dict.get("minimal_instances", {}))
+    cluster_config = update_dataclass(
+        DisaggClusterConfig(cluster_uri or cluster_config_dict["cluster_uri"]),
+        cluster_config_dict,
+    )
+    return cluster_config
+
+
 def split_world_comm(
         server_configs: List[CtxGenServerConfig]) -> Tuple[bool, int, Comm]:
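For illustration, a call shaped like the unit-test fixtures later in this diff; any field left out (here generation_servers under minimal_instances) keeps its dataclass default:

    # hypothetical usage; field names follow the test fixtures below
    cfg = extract_disagg_cluster_config({
        "cluster_uri": "http://localhost:18000",
        "cluster_name": "test_cluster",
        "minimal_instances": {"context_servers": 2},
    })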
@@ -281,9 +281,10 @@ class HttpClusterStorageServer(ClusterStorage):
                     self._storage.pop(k)
                 for k, v in kv_to_delete.items():
                     await self._notify_watch_event(k, v, WatchEventType.DELETE)
-                logger.debug(
-                    f"Checked expired, {before_len} -> {len(self._storage)}, keys to delete: {kv_to_delete.keys()}"
-                )
+                if len(kv_to_delete) > 0:
+                    logger.debug(
+                        f"Checked expired, {before_len} -> {len(self._storage)}, keys to delete: {kv_to_delete.keys()}"
+                    )
             except Exception as e:
                 logger.error(f"Error checking expired: {e}")
@@ -298,9 +299,12 @@ class HttpClusterStorageClient(ClusterStorage):
         self._cluster_name = cluster_name

     def __del__(self):
-        if asyncio.get_event_loop():
-            asyncio.run_coroutine_threadsafe(self._session.close(),
-                                             asyncio.get_event_loop())
+        try:
+            if asyncio.get_event_loop():
+                asyncio.run_coroutine_threadsafe(self._session.close(),
+                                                 asyncio.get_event_loop())
+        except RuntimeError:
+            pass

     def _url_for(self, endpoint: str) -> str:
         return f"{self._cluster_uri}/{endpoint}"
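asyncio.get_event_loop() raises RuntimeError rather than returning None when no loop exists (for example in a finalizer running off the main thread, or at interpreter shutdown), so the bare if was never a sufficient guard; the same try/except fix is applied to the two __del__ methods below. A minimal reproduction of the failure mode:

    import asyncio
    import threading

    def finalizer():
        try:
            if asyncio.get_event_loop():  # raises RuntimeError off the main thread
                print("loop available")
        except RuntimeError:
            pass  # no event loop: skip the async cleanup, as the diff now does

    threading.Thread(target=finalizer).start()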
@@ -46,9 +46,13 @@ class DisaggClusterManager:
         self._watch_handle = None

     def __del__(self):
-        if asyncio.get_event_loop():
-            asyncio.run_coroutine_threadsafe(self.stop(),
-                                             asyncio.get_event_loop())
+        try:
+            if asyncio.get_event_loop():
+                asyncio.run_coroutine_threadsafe(self.stop(),
+                                                 asyncio.get_event_loop())
+        except RuntimeError:
+            # the event loop may not be running when the cluster manager is destroyed
+            pass

     async def start(self) -> None:
         await self._cluster_storage.start()
@@ -208,9 +212,13 @@ class DisaggClusterWorker:
         self._worker_id = f"{role.name}-{host}:{port}-{int(time.time()*1000)}-{os.getpid()}-{random.randint(0, 1000):03}"

     def __del__(self):
-        if asyncio.get_event_loop():
-            asyncio.run_coroutine_threadsafe(self.deregister_worker(),
-                                             asyncio.get_event_loop())
+        try:
+            if asyncio.get_event_loop():
+                asyncio.run_coroutine_threadsafe(self.deregister_worker(),
+                                                 asyncio.get_event_loop())
+        except RuntimeError:
+            # the event loop may not be running when the worker is destroyed
+            pass

     @property
     def worker_id(self) -> str:
@@ -20,9 +20,13 @@ from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR
 # yapf: disable
 from tensorrt_llm.executor import CppExecutorError
 from tensorrt_llm.llmapi.disagg_utils import (DisaggServerConfig,
-                                              MetadataServerConfig,
+                                              MetadataServerConfig, ServerRole,
                                               get_ctx_gen_server_urls)
 from tensorrt_llm.logger import logger
+from tensorrt_llm.serve.cluster_storage import (WatchEventType,
+                                                create_cluster_storage)
+from tensorrt_llm.serve.disagg_auto_scaling import (DisaggClusterManager,
+                                                    WorkerInfo)
 from tensorrt_llm.serve.metadata_server import create_metadata_server
 from tensorrt_llm.serve.openai_protocol import (ChatCompletionRequest,
                                                 ChatCompletionResponse,
@@ -53,7 +57,6 @@ class OpenAIDisaggServer:
         self.gen_router = create_router(
             config.gen_router_config, self.gen_servers, metadata_server_cfg, self.metadata_server)
         self.conditional_disagg_config = config.conditional_disagg_config
-
         self.perf_metrics_max_requests = config.perf_metrics_max_requests
         if self.perf_metrics_max_requests > 0:
             # record corresponding keys of context and generation servers for perf metrics
@@ -82,18 +85,27 @@ class OpenAIDisaggServer:
         self._metrics_task = None
         self.metrics_interval_secs = metrics_interval_secs

+        self.disagg_cluster_config = config.disagg_cluster_config
+        self.disagg_cluster_storage = None
+        self.disagg_cluster_manager = None
+        self._update_worker_task = None
+
         logger.info(f"Server max retries: {self.max_retries}")

-        if (len(self.gen_servers) == 0):
-            raise ValueError("At least one generation server must be provided")
+        if self.disagg_cluster_config is None:
+            if (len(self.gen_servers) == 0):
+                raise ValueError("At least one generation server must be provided")

-        if os.getenv("TRTLLM_DISAGG_BENCHMARK_GEN_ONLY") != "1" and len(self.ctx_servers) == 0:
-            raise ValueError("At least one context server must be provided")
+            if os.getenv("TRTLLM_DISAGG_BENCHMARK_GEN_ONLY") != "1" and len(self.ctx_servers) == 0:
+                raise ValueError("At least one context server must be provided")

         if self.conditional_disagg_config is not None and \
             not isinstance(self.gen_router, KvCacheAwareRouter):
             raise ValueError("Generation router must be a KvCacheAwareRouter to enable conditional disaggregation")

+        if self.disagg_cluster_config and self.metadata_server:
+            raise ValueError("Cluster manager and metadata server cannot be used together")
+
         # Session will be initialized in lifespan
         self.session: Optional[aiohttp.ClientSession] = None
@@ -104,6 +116,11 @@ class OpenAIDisaggServer:
                 connector=aiohttp.TCPConnector(limit=0, limit_per_host=0, force_close=True),
                 timeout=aiohttp.ClientTimeout(total=req_timeout_secs))

+            if self.disagg_cluster_manager:
+                await self.disagg_cluster_manager.start()
+                await self.disagg_cluster_manager.watch_workers()
+                self._update_worker_task = asyncio.create_task(self._update_router_by_watch_events())
+
             logger.info("Waiting for context and generation servers to be ready")
             await self.wait_for_servers_ready(server_start_timeout_secs)
@@ -135,6 +152,9 @@ class OpenAIDisaggServer:
                     pass

             await self.session.close()  # Ensure session cleanup
+            if self.disagg_cluster_manager:
+                self._update_worker_task.cancel()
+                await self.disagg_cluster_manager.stop()

         self.app = FastAPI(lifespan=lifespan)
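The two lifespan hunks above tie the cluster manager's watch task to the FastAPI app's startup and shutdown. A minimal, self-contained sketch of that pattern (names illustrative, not the real classes):

    import asyncio
    from contextlib import asynccontextmanager

    from fastapi import FastAPI

    async def poll_forever():
        # stand-in for _update_router_by_watch_events
        while True:
            await asyncio.sleep(1)

    @asynccontextmanager
    async def lifespan(app: FastAPI):
        task = asyncio.create_task(poll_forever())  # started once the app is up
        yield
        task.cancel()  # mirrored by the cancel/stop branch in the diff

    app = FastAPI(lifespan=lifespan)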
@@ -145,6 +165,10 @@ class OpenAIDisaggServer:
             return JSONResponse(status_code=400, content={"error": str(exc)})

         self.register_routes()
+        if self.disagg_cluster_config:
+            self.disagg_cluster_storage = create_cluster_storage(self.disagg_cluster_config.cluster_uri, self.disagg_cluster_config.cluster_name, server=self.app)
+            self.disagg_cluster_manager = DisaggClusterManager(self.disagg_cluster_config, self.disagg_cluster_storage)

     async def _increment_metric(self, key: str, amount: int = 1):
         if self.metrics_interval_secs > 0:
@@ -185,14 +209,24 @@ class OpenAIDisaggServer:
         self.app.add_api_route("/v1/chat/completions",
                                self.openai_chat_completion,
                                methods=["POST"])
+        self.app.add_api_route("/cluster_info", self.cluster_info, methods=["GET"])

     async def health(self) -> Response:
+        if not await self.is_ready():
+            return Response(status_code=500)
         return Response(status_code=200)

     async def version(self) -> JSONResponse:
         ver = {"version": VERSION}
         return JSONResponse(content=ver)

+    async def cluster_info(self) -> JSONResponse:
+        if self.disagg_cluster_manager:
+            cluster_info = await self.disagg_cluster_manager.cluster_info()
+            cluster_info["is_ready"] = await self.is_ready()
+            return JSONResponse(content=cluster_info)
+        return JSONResponse(content={})
+
     async def _add_perf_metrics_keys(self, ctx_server: str, gen_server: str, ctx_request_id: int, raw_request: Request):
         async with self.perf_metrics_keys_lock:
             self.perf_metrics_keys.append((ctx_server, gen_server, ctx_request_id, raw_request.state.server_arrival_time, raw_request.state.server_first_token_time))
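A hypothetical client-side probe of the new endpoint; the is_ready flag and current_workers layout match what verify_cluster_info asserts in the integration tests below:

    import requests

    resp = requests.get("http://localhost:8000/cluster_info")  # host/port illustrative
    info = resp.json()
    if info.get("is_ready"):
        print("ctx workers:", len(info["current_workers"]["context_servers"]))
        print("gen workers:", len(info["current_workers"]["generation_servers"]))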
@@ -252,6 +286,8 @@ class OpenAIDisaggServer:


     async def openai_completion(self, req: CompletionRequest, raw_request: Request) -> Response:
+        if not await self.is_ready():
+            raise HTTPException(status_code=400, detail="Cluster is not ready")
         try:
             if not isinstance(req.prompt, str):
                 # Check if it's a list and contains integers
@@ -266,7 +302,8 @@ class OpenAIDisaggServer:
             await self._handle_exception(e)

     async def openai_chat_completion(self, req: ChatCompletionRequest, raw_request: Request) -> Response:
-
+        if not await self.is_ready():
+            raise HTTPException(status_code=400, detail="Cluster is not ready")
         try:
             return await self._send_disagg_request(req, raw_request)
         except Exception as e:
@@ -574,5 +611,31 @@ class OpenAIDisaggServer:
             raise TimeoutError("Timeout waiting for context and generation servers to be ready")
         logger.info("Context and generation servers are ready")

+    async def is_ready(self) -> bool:
+        if self.disagg_cluster_manager:
+            return await self.disagg_cluster_manager.is_ready_with_router(len(self.ctx_router.servers), len(self.gen_router.servers))
+        return True
+
     async def wait_for_servers_ready(self, server_start_timeout_secs: int = 180):
         await self.wait_for_all_servers_ready(self.session, self.ctx_servers, self.gen_servers, server_start_timeout_secs)
+
+    async def _update_router_by_watch_events(self):
+        def worker_repr(worker_info: WorkerInfo):
+            return f"http://{worker_info.host}:{worker_info.port}"
+        router_map = {
+            ServerRole.CONTEXT: self.ctx_router,
+            ServerRole.GENERATION: self.gen_router
+        }
+        logger.info("Start updating routers by worker events")
+        while True:
+            try:
+                worker_events = await self.disagg_cluster_manager.get_worker_events()
+                for worker_info, event_type in worker_events:
+                    if event_type == WatchEventType.SET:
+                        await router_map[worker_info.role].add_server(worker_repr(worker_info))
+                    elif event_type == WatchEventType.DELETE:
+                        await router_map[worker_info.role].remove_server(worker_repr(worker_info))
+                    logger.info(f"Worker {event_type.name} event: {worker_info.worker_id}")
+            except Exception as e:
+                logger.error(f"Error updating routers by worker events: {e}")
+            await asyncio.sleep(1)
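The loop assumes get_worker_events() yields (WorkerInfo, WatchEventType) pairs; SET grows the matching router's pool, DELETE shrinks it, and failures only log and retry. A toy driver of the same contract (all names here are local stand-ins, not the real classes):

    import asyncio
    from enum import Enum

    class EventType(Enum):
        SET = 1
        DELETE = 2

    async def apply_events(events, routers):
        # routers: role -> set of server URLs, mimicking add_server/remove_server
        for (role, url), event_type in events:
            if event_type == EventType.SET:
                routers[role].add(url)
            elif event_type == EventType.DELETE:
                routers[role].discard(url)

    routers = {"CONTEXT": set(), "GENERATION": set()}
    asyncio.run(apply_events([(("CONTEXT", "http://10.0.0.5:8001"), EventType.SET)], routers))
    assert routers["CONTEXT"] == {"http://10.0.0.5:8001"}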
@@ -28,12 +28,15 @@ from tensorrt_llm.inputs.data import TokensPrompt
 from tensorrt_llm.inputs.utils import ConversationMessage, apply_chat_template
 from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams
 from tensorrt_llm.llmapi import MultimodalEncoder
-from tensorrt_llm.llmapi.disagg_utils import MetadataServerConfig, ServerRole
+from tensorrt_llm.llmapi.disagg_utils import (DisaggClusterConfig,
+                                              MetadataServerConfig, ServerRole)
 from tensorrt_llm.llmapi.llm import RequestOutput
 from tensorrt_llm.logger import logger
 from tensorrt_llm.metrics.collector import MetricsCollector
 from tensorrt_llm.serve.chat_utils import (check_multiple_response,
                                            parse_chat_messages_coroutines)
+from tensorrt_llm.serve.cluster_storage import create_cluster_storage_client
+from tensorrt_llm.serve.disagg_auto_scaling import DisaggClusterWorker
 from tensorrt_llm.serve.metadata_server import create_metadata_server
 from tensorrt_llm.serve.openai_protocol import (ChatCompletionRequest,
                                                 ChatCompletionResponse,
@@ -75,12 +78,17 @@ class OpenAIServer:
                  llm: Union[LLM, MultimodalEncoder],
                  model: str,
                  server_role: Optional[ServerRole],
-                 metadata_server_cfg: MetadataServerConfig):
+                 metadata_server_cfg: MetadataServerConfig,
+                 disagg_cluster_config: Optional[DisaggClusterConfig] = None):
         self.llm = llm
         self.tokenizer = llm.tokenizer
         self.metadata_server = create_metadata_server(metadata_server_cfg)
+        self.disagg_cluster_config = disagg_cluster_config
         self.server_role = server_role
-        self.binding_addr = None # Will be set in __call__
+        # Will be set in __call__
+        self.binding_addr = None
+        self.host = None
+        self.port = None
         hf_tokenizer_path = llm._hf_model_dir or self.tokenizer.tokenizer.name_or_path
         trust_remote_code = llm.args.trust_remote_code
         try:
@@ -129,6 +137,10 @@ class OpenAIServer:
         else:
             self.use_harmony = (self.model_config.model_type == "gpt_oss")

+        # as disagg-worker
+        self.disagg_cluster_storage = None
+        self.disagg_cluster_worker = None
+
         @asynccontextmanager
         async def lifespan(app: FastAPI):
             if self.metadata_server is not None:
@@ -144,12 +156,19 @@ class OpenAIServer:
                 self.metadata_server.put(f"trtllm/{self.llm.llm_id}", metadata)
                 logger.info(f"trtllm/{self.llm.llm_id} is registered")

+            if self.disagg_cluster_config:
+                self.disagg_cluster_storage = create_cluster_storage_client(self.disagg_cluster_config.cluster_uri, self.disagg_cluster_config.cluster_name)
+                self.disagg_cluster_worker = DisaggClusterWorker(self.server_role, self.host, self.port, self.disagg_cluster_config, self.disagg_cluster_storage)
+                await self.disagg_cluster_worker.register_worker()
+
             # terminate rank0 worker
             yield

             if self.metadata_server is not None:
                 self.metadata_server.remove(f"trtllm/{self.llm.llm_id}")
                 logger.info(f"trtllm/{self.llm.llm_id} is unregistered")
+            if self.disagg_cluster_worker:
+                await self.disagg_cluster_worker.deregister_worker()
             self.llm.shutdown()

         self.app = FastAPI(lifespan=lifespan)
@@ -949,6 +968,8 @@ class OpenAIServer:
     async def __call__(self, host, port):
         # Store the binding address for server registration
         self.binding_addr = f"http://{host}:{port}"
+        self.host = host
+        self.port = port
         config = uvicorn.Config(self.app,
                                 host=host,
                                 port=port,
@@ -164,6 +164,34 @@ class Router(ABC):
             new_servers: The new server list
         """

+    @property
+    def servers(self) -> List[str]:
+        return self._servers
+
+    async def add_server(self, server: str):
+        if server in self._servers:
+            logger.warning(f"Server {server} already exists")
+            return
+        async with self._lock:
+            old_servers = self._servers.copy()
+            self._servers = [*old_servers, server]
+            self._on_servers_updated(old_servers, self._servers)
+            logger.debug(
+                f"Added server {server}, current server list: {self._servers}")
+
+    async def remove_server(self, server: str):
+        if server not in self._servers:
+            logger.warning(f"Server {server} does not exist")
+            return
+        async with self._lock:
+            old_servers = self._servers.copy()
+            self._servers = [
+                old_server for old_server in old_servers if old_server != server
+            ]
+            self._on_servers_updated(old_servers, self._servers)
+            logger.debug(
+                f"Removed server {server}, current server list: {self._servers}")
+
     @abstractmethod
     async def get_next_server(self, request: OpenAIRequest) -> tuple[str, dict]:
         '''Select server by request and return some intermediate information'''
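These helpers let the disaggregated server resize a live router as workers come and go; a hypothetical use (constructor arguments mirror the create_router call further down):

    import asyncio

    async def resize_demo(router):
        await router.add_server("http://10.0.0.6:8001")     # worker registered
        await router.remove_server("http://10.0.0.5:8001")  # heartbeat lapsed
        print(router.servers)

    # e.g. asyncio.run(resize_demo(RoundRobinRouter(None, ["http://10.0.0.5:8001"], None, None)))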
@@ -523,6 +551,7 @@ class KvCacheAwareRouter(Router):
         super().__init__(server_role, servers, metadata_server_cfg,
                          metadata_server)
+        self._lock = asyncio.Lock()
         self._use_tokens = use_tokens

         # Load map between servers and their number of tokens processed
         self._server_state: dict[str, KvCacheAwareServerState] = {
@@ -557,7 +586,8 @@ class KvCacheAwareRouter(Router):
         return [tokenizer(prompt)["input_ids"] for prompt in prompts]

     async def get_next_server(self, request: OpenAIRequest) -> tuple[str, dict]:
-        servers = list(self._server_state.keys())
+        async with self._lock:
+            servers = list(self._server_state.keys())
         token_lists = self._tokenize(request)
         block_hashes: list[list[int]] = []
         for token_list in token_lists:
@@ -589,8 +619,8 @@ class KvCacheAwareRouter(Router):
                     i] / self._max_batch_size
             scores.append(score)
         server = servers[scores.index(max(scores))]
-        await self._server_state[server].increment_load(request)
+        async with self._lock:
+            await self._server_state[server].increment_load(request)
         self._req_routing_table[id(request)] = server
         return server, {
             "block_hashes": block_hashes,  # list[list[int]]
@@ -604,8 +634,9 @@ class KvCacheAwareRouter(Router):
         async with self._lock:
             server = self._req_routing_table[id(request)]
             del self._req_routing_table[id(request)]
-        await self._server_state[server].decrement_load(request,
-                                                        session=session)
+            if server in self._server_state:
+                await self._server_state[server].decrement_load(request,
+                                                                session=session)

     def _on_servers_updated(self, old_servers, new_servers):
         raise NotImplementedError(
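The new membership check guards the case where a server is no longer tracked in _server_state by the time one of its requests completes, which can now happen once the server pool changes at runtime. A toy timeline with plain dicts standing in for the router state:

    server_state = {"http://a:1": 1}   # one in-flight request routed to this worker
    routing_table = {42: "http://a:1"}

    server_state.pop("http://a:1")     # worker disappears mid-request

    server = routing_table.pop(42)     # finish_request: resolve the routing entry
    if server in server_state:         # guard added by the diff
        server_state[server] -= 1      # skipped here, avoiding a KeyError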
@@ -632,22 +663,19 @@ def create_router(router_config: Optional[RouterConfig],
     Raises:
         ValueError: If an unsupported router type is provided
     """
-    if router_config is None:
-        # Create a default router without server_role
-        return RoundRobinRouter(None, servers)
-
     router_map = {
         "round_robin": RoundRobinRouter,
         "load_balancing": LoadBalancingRouter,
         "kv_cache_aware": KvCacheAwareRouter,
     }

-    router_type = router_config.type
+    router_type = router_config.type if router_config else "round_robin"
     router_class = router_map.get(router_type.lower())

     if router_class is None:
         raise ValueError(f"Unsupported router type: {router_type}. "
                          f"Supported types are: {list(router_map.keys())}")
+    extra_args = router_config.args if router_config else {}

     # Pass server_role as the first argument
-    return router_class(router_config.server_role, servers, metadata_server_cfg,
-                        metadata_server, **router_config.args)
+    return router_class(router_config.server_role if router_config else None,
+                        servers, metadata_server_cfg, metadata_server,
+                        **extra_args)
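With the early return gone, a missing router_config flows through the same construction path and still yields a round-robin router; for example (hypothetical call, positional arguments as in the usage near the top of the disagg server's diff):

    # no RouterConfig: defaults to RoundRobinRouter with no server_role
    router = create_router(None, ["http://localhost:8001", "http://localhost:8002"],
                           None, None)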
397 tests/integration/defs/disaggregated/test_auto_scaling.py (new file)
@@ -0,0 +1,397 @@
import asyncio
import os
import subprocess
import tempfile

import openai
import pytest
import requests
import yaml

from tensorrt_llm.logger import logger

TEST_PORT = 18000
HEARTBEAT_INTERVAL = 1
INACTIVE_TIMEOUT = 2

ROUTER_TYPES = ["round_robin",
                "load_balancing"]  # kv_cache_aware doesn't support auto-scaling


@pytest.fixture
def model_name():
    return "TinyLlama/TinyLlama-1.1B-Chat-v1.0"


@pytest.fixture
def disagg_cluster_config():
    # same cluster config for workers and proxy server
    return {
        "cluster_uri": f"http://localhost:{TEST_PORT}",
        "cluster_name": "test_cluster",
        "heartbeat_interval_sec": HEARTBEAT_INTERVAL,
        "inactive_timeout_sec": INACTIVE_TIMEOUT,
    }


@pytest.fixture
def router(request):
    return request.param


@pytest.fixture
def disagg_server_config(disagg_cluster_config, router):
    return {
        "hostname": "localhost",
        "port": TEST_PORT,
        "disagg_cluster": disagg_cluster_config,
        "context_servers": {
            "router": {
                "type": router
            }
        },
        "generation_servers": {
            "router": {
                "type": router
            }
        },
    }
@pytest.fixture
def worker_config(disagg_cluster_config):
    return {
        "disagg_cluster": disagg_cluster_config,
        "disable_overlap_scheduler": True,
        "cache_transceiver_config": {
            "backend": "DEFAULT"
        },
        "kv_cache_config": {
            "free_gpu_memory_fraction": 0.2,
            "enable_partial_reuse": False,
        },
        "cuda_graph_config": {},
    }
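_run_worker below dumps this fixture to YAML and hands it to each worker via --extra_llm_api_options; rendered, the file looks roughly like the output of this snippet (values taken from the fixtures above):

    import yaml

    print(yaml.safe_dump({
        "disagg_cluster": {
            "cluster_uri": "http://localhost:18000",  # f"http://localhost:{TEST_PORT}"
            "cluster_name": "test_cluster",
            "heartbeat_interval_sec": 1,
            "inactive_timeout_sec": 2,
        },
        "disable_overlap_scheduler": True,
    }))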
def _run_worker(model_name, worker_config, role, port=8000, device=-1):
    worker_config_path = tempfile.NamedTemporaryFile(delete=False)
    with open(worker_config_path.name, "w+") as f:
        yaml.dump(worker_config, f)
        f.flush()
    cmd = [
        "trtllm-serve",
        "serve",
        model_name,
        "--host",
        "localhost",
        "--port",
        str(port),
        "--extra_llm_api_options",
        worker_config_path.name,
        "--server_role",
        "context" if role.startswith("ctx") else "generation",
    ]
    env = os.environ.copy()
    if device != -1:
        env["CUDA_VISIBLE_DEVICES"] = str(device)
    return subprocess.Popen(cmd, env=env)
def run_ctx_worker(model_name,
                   ctx_worker_config,
                   port=TEST_PORT + 100,
                   device=0):
    return _run_worker(model_name, ctx_worker_config, "ctx", port, device)


def run_gen_worker(model_name,
                   gen_worker_config,
                   port=TEST_PORT + 200,
                   device=1):
    return _run_worker(model_name, gen_worker_config, "gen", port, device)
def run_disagg_server(disagg_cluster_config, port=TEST_PORT):
    disagg_server_config_path = f"/tmp/disagg_server_{port}_config.yaml"
    disagg_cluster_config["port"] = port
    with open(disagg_server_config_path, "w+") as f:
        yaml.dump(disagg_cluster_config, f)
    cmds = ["trtllm-serve", "disaggregated", "-c", disagg_server_config_path]
    f = open("disagg_server.log", "w+")
    p = subprocess.Popen(cmds, stdout=f, stderr=f)
    return p
async def wait_for_disagg_server_ready(port):
    while True:
        await asyncio.sleep(3)
        logger.info("Waiting for disagg server to be ready")
        try:
            info_resp = requests.get(f"http://localhost:{port}/cluster_info")
            if info_resp.status_code == 200:
                info = info_resp.json()
                if info["is_ready"]:
                    break
                logger.info(
                    f"Waiting for disagg server to be ready: {info_resp.json()}"
                )
            else:
                logger.info(
                    f"Failed to get cluster info: {info_resp.status_code}")
                await asyncio.sleep(3)
        except requests.exceptions.RequestException as e:
            logger.error(f"Failed to get cluster info: {e}")


async def wait_for_worker_ready(port):
    while True:
        await asyncio.sleep(3)
        logger.info(f"Waiting for worker {port} to be ready")
        try:
            info_resp = requests.get(f"http://localhost:{port}/health")
            if info_resp.status_code == 200:
                break
        except requests.exceptions.RequestException as e:
            logger.info(f"Failed to get worker info: {e}")
def verify_cluster_info(ready,
                        ctx_workers=-1,
                        gen_workers=-1,
                        port=TEST_PORT,
                        expected_code=200):
    info_resp = requests.get(f"http://localhost:{port}/cluster_info")
    assert info_resp.status_code == expected_code
    info = info_resp.json()
    print("verify_cluster_info", info, ready, ctx_workers, gen_workers)
    assert info["is_ready"] == ready
    if ctx_workers != -1:
        assert len(info["current_workers"]["context_servers"]) == ctx_workers
    if gen_workers != -1:
        assert len(info["current_workers"]["generation_servers"]) == gen_workers
def terminate(*args):
    try:
        for arg in args:
            if arg and isinstance(arg, subprocess.Popen):
                arg.terminate()
                arg.wait(timeout=10)
    except Exception:
        pass
def request_completion(model_name, prompt, port=TEST_PORT):
    client = openai.OpenAI(api_key="tensorrt_llm",
                           base_url=f"http://localhost:{port}/v1")
    return client.completions.create(model=model_name,
                                     prompt=prompt,
                                     max_tokens=10,
                                     temperature=0.0)
@pytest.mark.skip_less_device(2)
@pytest.mark.parametrize("router", ROUTER_TYPES, indirect=True)
@pytest.mark.asyncio(loop_scope="module")
@pytest.mark.timeout(600)
async def test_service_discovery(model_name, disagg_server_config,
                                 worker_config, router):
    ctx_worker1 = None
    gen_worker1 = None
    disagg_server = None
    try:
        # initial cluster, 1 ctx, 1 gen, request should succeed
        ctx_worker1 = run_ctx_worker(model_name, worker_config, TEST_PORT + 100)
        gen_worker1 = run_gen_worker(model_name, worker_config, TEST_PORT + 200)
        disagg_server = run_disagg_server(disagg_server_config, TEST_PORT)
        await wait_for_disagg_server_ready(TEST_PORT)
        verify_cluster_info(True, 1, 1)
        response = request_completion(model_name,
                                      "Hello, my name is",
                                      port=TEST_PORT)
        print(response)
    finally:
        terminate(ctx_worker1, gen_worker1, disagg_server)
@pytest.mark.skip_less_device(2)
@pytest.mark.parametrize(
    "router", ["round_robin"], indirect=True
)  # use only round_robin to reduce the test time; the router type doesn't matter for this test
@pytest.mark.asyncio(loop_scope="module")
@pytest.mark.timeout(600)
async def test_minimal_instances(model_name, disagg_server_config,
                                 worker_config, router):
    # the cluster should have at least 2 ctx and 2 gen workers
    minimal_instances = {
        "context_servers": 2,
        "generation_servers": 2,
    }
    disagg_server_config["disagg_cluster"][
        "minimal_instances"] = minimal_instances
    worker_config["disagg_cluster"]["minimal_instances"] = minimal_instances

    processes = []

    try:
        processes.append(
            run_ctx_worker(model_name, worker_config, TEST_PORT + 100))
        processes.append(
            run_gen_worker(model_name, worker_config, TEST_PORT + 200))
        processes.append(run_disagg_server(disagg_server_config, TEST_PORT))
        await wait_for_worker_ready(TEST_PORT + 100)
        await wait_for_worker_ready(TEST_PORT + 200)
        verify_cluster_info(False, 1, 1)
        # with only 1 ctx and 1 gen worker, the request should fail
        with pytest.raises(Exception):
            response = request_completion(model_name,
                                          "Hello, my name is",
                                          port=TEST_PORT)
            print(response)

        processes.append(
            run_ctx_worker(model_name, worker_config, TEST_PORT + 101))
        processes.append(
            run_gen_worker(model_name, worker_config, TEST_PORT + 201))
        await wait_for_disagg_server_ready(TEST_PORT)
        verify_cluster_info(True, 2, 2)
        response = request_completion(model_name,
                                      "Hello, my name is",
                                      port=TEST_PORT)
        print(response)
    finally:
        terminate(*processes)
@pytest.mark.skip_less_device(2)
@pytest.mark.parametrize("router", ROUTER_TYPES, indirect=True)
@pytest.mark.asyncio(loop_scope="module")
@pytest.mark.timeout(600)
async def test_worker_restart(model_name, disagg_server_config, worker_config,
                              router):
    ctx_worker1 = None
    ctx_worker2 = None
    gen_worker1 = None
    gen_worker2 = None
    disagg_server = None

    try:
        # initial cluster, 1 ctx, 1 gen, request should succeed
        ctx_worker1 = run_ctx_worker(model_name,
                                     worker_config,
                                     TEST_PORT + 100,
                                     device=0)
        gen_worker1 = run_gen_worker(model_name,
                                     worker_config,
                                     TEST_PORT + 200,
                                     device=1)
        disagg_server = run_disagg_server(disagg_server_config, TEST_PORT)
        await wait_for_disagg_server_ready(TEST_PORT)
        verify_cluster_info(True, 1, 1)
        response = request_completion(model_name,
                                      "Hello, my name is",
                                      port=TEST_PORT)
        print(response)
        # kill gen1, the request should fail
        terminate(gen_worker1)
        await asyncio.sleep(INACTIVE_TIMEOUT)
        verify_cluster_info(False, 1, 0)
        with pytest.raises(Exception):
            request_completion(model_name, "Hello, my name is", port=TEST_PORT)

        test_prompt = "The capital of France is"

        # add gen2, the request should succeed
        gen_worker2 = run_gen_worker(model_name,
                                     worker_config,
                                     TEST_PORT + 201,
                                     device=2)
        await wait_for_worker_ready(TEST_PORT + 201)
        await asyncio.sleep(INACTIVE_TIMEOUT)
        verify_cluster_info(True, 1, 1)

        response = request_completion(model_name, test_prompt, port=TEST_PORT)
        print(response)
        response_text = response.choices[0].text
        assert len(response.choices[0].text) >= 1

        # kill ctx1, the request should fail
        terminate(ctx_worker1)
        await asyncio.sleep(INACTIVE_TIMEOUT)
        verify_cluster_info(False, 0, 1)
        with pytest.raises(Exception):
            request_completion(model_name, test_prompt, port=TEST_PORT)

        # add ctx2, the request should succeed
        ctx_worker2 = run_ctx_worker(model_name,
                                     worker_config,
                                     TEST_PORT + 101,
                                     device=3)
        await wait_for_worker_ready(TEST_PORT + 101)
        verify_cluster_info(True, 1, 1)

        response = request_completion(model_name, test_prompt, port=TEST_PORT)
        response_text = response.choices[0].text
        assert len(response.choices[0].text) >= 1

        # restart ctx1 and gen1 with the same ports; we have 2 ctxs and 2 gens now
        ctx_worker1 = run_ctx_worker(model_name, worker_config, TEST_PORT + 100)
        gen_worker1 = run_gen_worker(model_name, worker_config, TEST_PORT + 200)
        await wait_for_worker_ready(TEST_PORT + 100)
        await wait_for_worker_ready(TEST_PORT + 200)
        await asyncio.sleep(INACTIVE_TIMEOUT)
        verify_cluster_info(True, 2, 2)

        # send 10 requests; the responses will be generated by different ctx/gen workers (but we can't verify that yet)
        for _ in range(10):
            response = request_completion(model_name,
                                          test_prompt,
                                          port=TEST_PORT)
            assert response.choices[0].text == response_text
            print(response)
    finally:
        terminate(ctx_worker1, ctx_worker2, gen_worker1, gen_worker2,
                  disagg_server)
@pytest.mark.skip_less_device(2)
@pytest.mark.parametrize("router", ["round_robin"], indirect=True)
@pytest.mark.asyncio(loop_scope="module")
@pytest.mark.timeout(300)
async def test_disagg_server_restart(model_name, disagg_server_config,
                                     worker_config, router):
    ctx_worker1 = None
    gen_worker1 = None
    disagg_server = None
    try:
        # initial cluster, 1 ctx, 1 gen, request should succeed
        ctx_worker1 = run_ctx_worker(model_name, worker_config, TEST_PORT + 100)
        gen_worker1 = run_gen_worker(model_name, worker_config, TEST_PORT + 200)
        disagg_server = run_disagg_server(disagg_server_config, TEST_PORT)
        await wait_for_disagg_server_ready(TEST_PORT)
        verify_cluster_info(True, 1, 1)
        response = request_completion(model_name,
                                      "Hello, my name is",
                                      port=TEST_PORT)
        print(response)
        response_text = response.choices[0].text

        # kill disagg server, the request should fail
        terminate(disagg_server)
        await asyncio.sleep(INACTIVE_TIMEOUT)
        with pytest.raises(Exception):
            verify_cluster_info(False, 1, 1, expected_code=500)

        # restart disagg server, the request should succeed
        disagg_server = run_disagg_server(disagg_server_config, TEST_PORT)
        await wait_for_disagg_server_ready(TEST_PORT)
        verify_cluster_info(True, 1, 1)
        response = request_completion(model_name,
                                      "Hello, my name is",
                                      port=TEST_PORT)
        print(response)
        assert response.choices[0].text == response_text

    finally:
        terminate(disagg_server, ctx_worker1, gen_worker1)
@@ -91,6 +91,12 @@ l0_dgx_h100:
   - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=2-ctx_pp=2]
   - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_multi_instance[GSM8K]
   - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_multi_instance[MMLU]
+  - disaggregated/test_auto_scaling.py::test_service_discovery[round_robin]
+  - disaggregated/test_auto_scaling.py::test_service_discovery[load_balancing]
+  - disaggregated/test_auto_scaling.py::test_worker_restart[round_robin]
+  - disaggregated/test_auto_scaling.py::test_worker_restart[load_balancing]
+  - disaggregated/test_auto_scaling.py::test_minimal_instances[round_robin]
+  - disaggregated/test_auto_scaling.py::test_disagg_server_restart[round_robin]
 - condition:
     ranges:
       system_gpu_count:
@@ -35,15 +35,50 @@ def get_yaml_config():
     return config


-@pytest.fixture
-def sample_yaml_config():
-    config = get_yaml_config()
+def get_yaml_config_with_disagg_cluster():
+    config = {
+        "hostname": "test_host",
+        "port": 9000,
+        "context_servers": {
+            "max_batch_size": 1,
+            "tensor_parallel_size": 2,
+            "pipeline_parallel_size": 1,
+        },
+        "generation_servers": {
+            "router": {
+                "type": "load_balancing",
+                "use_tokens": False,
+            },
+            "max_batch_size": 1,
+            "tensor_parallel_size": 1,
+            "pipeline_parallel_size": 1,
+        },
+        "disagg_cluster": {
+            "cluster_uri": "http://test_host:9000",
+            "cluster_name": "test_cluster",
+            "minimal_instances": {
+                "context_servers": 2,
+                "generation_servers": 2,
+            },
+            "heartbeat_interval_sec": 1,
+            "inactive_timeout_sec": 2,
+        },
+    }
     return config


 @pytest.fixture
-def sample_yaml_file(tmp_path):
-    config = get_yaml_config()
+def sample_yaml_config(request):
+    if request.param == "disagg_cluster":
+        config = get_yaml_config_with_disagg_cluster()
+    else:
+        config = get_yaml_config()
+    return config
+
+
+@pytest.fixture
+def sample_yaml_file(sample_yaml_config, tmp_path):
+    config = sample_yaml_config

     yaml_file = tmp_path / "test_config.yaml"
     with open(yaml_file, "w") as f:
@@ -51,24 +86,30 @@ def sample_yaml_file(tmp_path):
     return yaml_file


-def verify_disagg_config(config: DisaggServerConfig):
+def verify_disagg_config(config: DisaggServerConfig,
+                         sample_yaml_config: str = ""):
     assert config.hostname == "test_host"
     assert config.port == 9000
     assert config.ctx_router_config.type == "round_robin"
     assert config.gen_router_config.type == "load_balancing"
-    assert len(config.server_configs) == 3
+    if sample_yaml_config == "":
+        assert len(config.server_configs) == 3


-def test_parse_disagg_config_file(sample_yaml_file):
+@pytest.mark.parametrize("sample_yaml_config", ["disagg_cluster", ""],
+                         indirect=True)
+def test_parse_disagg_config_file(sample_yaml_file, sample_yaml_config):
     config = parse_disagg_config_file(sample_yaml_file)
     assert isinstance(config, DisaggServerConfig)
-    verify_disagg_config(config)
+    verify_disagg_config(config, sample_yaml_config)


+@pytest.mark.parametrize("sample_yaml_config", ["disagg_cluster", ""],
+                         indirect=True)
 def test_extract_disagg_cfg(sample_yaml_config):
     config = extract_disagg_cfg(**sample_yaml_config)
     assert isinstance(config, DisaggServerConfig)
-    verify_disagg_config(config)
+    verify_disagg_config(config, sample_yaml_config)


 def test_extract_ctx_gen_cfgs():
@@ -85,6 +126,7 @@ def test_extract_ctx_gen_cfgs():
     assert configs[0].instance_num_ranks == 2


+@pytest.mark.parametrize("sample_yaml_config", [""], indirect=True)
 def test_extract_router_config(sample_yaml_config):
     ctx_server_config = sample_yaml_config["context_servers"]
     gen_server_config = sample_yaml_config["generation_servers"]