mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
[https://nvbugs/5649010][fix] use 0 port as arbitrary port when disagg service discovery is enabled (#10383)
Signed-off-by: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com>
This commit is contained in:
parent
0517b62789
commit
82c1ba84a7
@ -18,7 +18,7 @@ from torch.cuda import device_count
|
||||
from tensorrt_llm import LLM as PyTorchLLM
|
||||
from tensorrt_llm import MultimodalEncoder
|
||||
from tensorrt_llm._tensorrt_engine import LLM
|
||||
from tensorrt_llm._utils import mpi_rank
|
||||
from tensorrt_llm._utils import get_free_port, mpi_rank
|
||||
from tensorrt_llm.executor.utils import LlmLauncherEnvs
|
||||
from tensorrt_llm.inputs.multimodal import MultimodalServerConfig
|
||||
from tensorrt_llm.llmapi import (BuildConfig, CapacitySchedulerPolicy,
|
||||
@ -180,10 +180,27 @@ def launch_server(
|
||||
backend = llm_args["backend"]
|
||||
model = llm_args["model"]
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
try:
|
||||
s.bind((host, port))
|
||||
except OSError as e:
|
||||
raise RuntimeError(f"Failed to bind socket to {host}:{port}: {e}")
|
||||
# If disagg cluster config is provided and port is not specified, try to find a free port, otherwise try to bind to the specified port
|
||||
assert port > 0 or disagg_cluster_config is not None, "Port must be specified if disagg cluster config is not provided"
|
||||
if port > 0:
|
||||
port_retries = 1
|
||||
else:
|
||||
port_retries = 100
|
||||
port = get_free_port()
|
||||
while port_retries > 0:
|
||||
try:
|
||||
s.bind((host, port))
|
||||
break
|
||||
except OSError as e:
|
||||
port_retries -= 1
|
||||
if port_retries == 0:
|
||||
raise RuntimeError(
|
||||
f"Failed to bind socket to {host}:{port}: {e}")
|
||||
else:
|
||||
logger.warning(
|
||||
f"Failed to bind socket to {host}:{port}: {e}, retrying {port_retries}..."
|
||||
)
|
||||
port = get_free_port()
|
||||
|
||||
if backend == 'pytorch':
|
||||
llm_args.pop("build_config", None)
|
||||
|
||||
@ -23,24 +23,6 @@ INACTIVE_TIMEOUT = 2
|
||||
CHECK_STATUS_INTERVAL = 3
|
||||
|
||||
ROUTER_TYPES = ["round_robin", "load_balancing", "kv_cache_aware"]
|
||||
USED_PORTS = set()
|
||||
|
||||
|
||||
# get_free_port doesn't guarantee that consecutive calls will return different ports
|
||||
# if no server is bound to the port immediately after the call
|
||||
def get_free_unused_port():
|
||||
global USED_PORTS
|
||||
max_attempts = 100
|
||||
for _ in range(max_attempts):
|
||||
port = get_free_port()
|
||||
assert port > 0, f"get_free_port returned {port}"
|
||||
if port not in USED_PORTS:
|
||||
USED_PORTS.add(port)
|
||||
return port
|
||||
else:
|
||||
logger.info(f"Port {port} is already used, trying another one")
|
||||
raise Exception(
|
||||
f"Failed to find a free unused port after {max_attempts} attempts")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -53,7 +35,7 @@ def model_name():
|
||||
|
||||
@pytest.fixture
|
||||
def disagg_port():
|
||||
return get_free_unused_port()
|
||||
return get_free_port()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -145,8 +127,6 @@ def _run_worker(model_name,
|
||||
work_dir,
|
||||
device=-1,
|
||||
save_log=False):
|
||||
if port == 0:
|
||||
port = get_free_unused_port()
|
||||
worker_config_path = os.path.join(work_dir, f"{role}_{port}_config.yaml")
|
||||
with open(worker_config_path, "w+") as f:
|
||||
yaml.dump(worker_config, f)
|
||||
@ -187,6 +167,7 @@ def _run_worker(model_name,
|
||||
port=port)
|
||||
|
||||
|
||||
# Use 0 as the port and provide disagg_cluster_config to let the worker choose a free port
|
||||
def run_ctx_worker(model_name, ctx_worker_config, work_dir, port=0, device=0):
|
||||
return _run_worker(model_name, ctx_worker_config, "ctx", port, work_dir,
|
||||
device)
|
||||
@ -246,19 +227,38 @@ def periodic_check(timeout=300, interval=3):
|
||||
return decorator
|
||||
|
||||
|
||||
@periodic_check(timeout=300, interval=3)
|
||||
async def wait_for_disagg_server_ready(port):
|
||||
async def _wait_for_disagg_server_status(port,
|
||||
ready=True,
|
||||
min_ctx_workers=-1,
|
||||
min_gen_workers=-1):
|
||||
info_resp = requests.get(f"http://localhost:{port}/cluster_info")
|
||||
logger.info(
|
||||
f"Waiting for disagg server {port} to be ready: {info_resp.json()}")
|
||||
if info_resp.status_code == 200:
|
||||
info = info_resp.json()
|
||||
return info["is_ready"]
|
||||
else:
|
||||
logger.info(f"Failed to get cluster info: {info_resp.status_code}")
|
||||
if ready:
|
||||
return info["is_ready"]
|
||||
else:
|
||||
return len(info["current_workers"]
|
||||
["context_servers"]) >= min_ctx_workers and len(
|
||||
info["current_workers"]
|
||||
["generation_servers"]) >= min_gen_workers
|
||||
return False
|
||||
|
||||
|
||||
@periodic_check(timeout=300, interval=3)
|
||||
async def wait_for_disagg_server_ready(port):
|
||||
return await _wait_for_disagg_server_status(port, True)
|
||||
|
||||
|
||||
@periodic_check(timeout=300, interval=3)
|
||||
async def wait_for_disagg_server_status(port,
|
||||
min_ctx_workers=-1,
|
||||
min_gen_workers=-1):
|
||||
return await _wait_for_disagg_server_status(port, False, min_ctx_workers,
|
||||
min_gen_workers)
|
||||
|
||||
|
||||
@periodic_check(timeout=300, interval=3)
|
||||
async def wait_for_worker_ready(port):
|
||||
logger.info(f"Waiting for worker {port} to be ready")
|
||||
@ -314,9 +314,6 @@ def terminate(*args, show_log_lines=30, release_port=True):
|
||||
if arg.log_file:
|
||||
arg.log_file.close()
|
||||
arg.log_file = None
|
||||
if release_port:
|
||||
global USED_PORTS
|
||||
USED_PORTS.discard(arg.port)
|
||||
except Exception:
|
||||
print(f"Failed to terminate process {arg.process.pid}")
|
||||
else:
|
||||
@ -399,8 +396,7 @@ async def test_minimal_instances(model_name, disagg_server_config,
|
||||
gen_worker1 = run_gen_worker(model_name, worker_config, work_dir)
|
||||
disagg_server = run_disagg_server(disagg_server_config, work_dir,
|
||||
disagg_port)
|
||||
await wait_for_worker_ready(ctx_worker1.port)
|
||||
await wait_for_worker_ready(gen_worker1.port)
|
||||
await wait_for_disagg_server_status(disagg_port, 1, 1)
|
||||
verify_cluster_info(False, 1, 1, port=disagg_port)
|
||||
# with only 1 ctx and 1 gen worker, the request should fail
|
||||
with pytest.raises(Exception):
|
||||
@ -470,7 +466,7 @@ async def test_worker_restart(model_name, disagg_server_config, worker_config,
|
||||
work_dir,
|
||||
port=0,
|
||||
device=2)
|
||||
await wait_for_worker_ready(gen_worker2.port)
|
||||
await wait_for_disagg_server_status(disagg_port, 1, 1)
|
||||
await asyncio.sleep(CHECK_STATUS_INTERVAL)
|
||||
verify_cluster_info(True, 1, 1, port=disagg_port)
|
||||
|
||||
@ -492,7 +488,8 @@ async def test_worker_restart(model_name, disagg_server_config, worker_config,
|
||||
work_dir,
|
||||
port=0,
|
||||
device=3)
|
||||
await wait_for_worker_ready(ctx_worker2.port)
|
||||
await wait_for_disagg_server_status(disagg_port, 1, 1)
|
||||
await asyncio.sleep(CHECK_STATUS_INTERVAL)
|
||||
verify_cluster_info(True, 1, 1, port=disagg_port)
|
||||
|
||||
response = request_completion(model_name, test_prompt, port=disagg_port)
|
||||
@ -510,8 +507,7 @@ async def test_worker_restart(model_name, disagg_server_config, worker_config,
|
||||
work_dir,
|
||||
port=0,
|
||||
device=1)
|
||||
await wait_for_worker_ready(ctx_worker1.port)
|
||||
await wait_for_worker_ready(gen_worker1.port)
|
||||
await wait_for_disagg_server_status(disagg_port, 2, 2)
|
||||
await asyncio.sleep(CHECK_STATUS_INTERVAL)
|
||||
verify_cluster_info(True, 2, 2, port=disagg_port)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user