[https://nvbugs/5649010][fix] use 0 port as arbitrary port when disagg service discovery is enabled (#10383)

Signed-off-by: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com>
This commit is contained in:
Lizhi Zhou 2026-01-05 09:40:40 +08:00 committed by GitHub
parent 0517b62789
commit 82c1ba84a7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 53 additions and 40 deletions

View File

@ -18,7 +18,7 @@ from torch.cuda import device_count
from tensorrt_llm import LLM as PyTorchLLM
from tensorrt_llm import MultimodalEncoder
from tensorrt_llm._tensorrt_engine import LLM
from tensorrt_llm._utils import mpi_rank
from tensorrt_llm._utils import get_free_port, mpi_rank
from tensorrt_llm.executor.utils import LlmLauncherEnvs
from tensorrt_llm.inputs.multimodal import MultimodalServerConfig
from tensorrt_llm.llmapi import (BuildConfig, CapacitySchedulerPolicy,
@ -180,10 +180,27 @@ def launch_server(
backend = llm_args["backend"]
model = llm_args["model"]
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
try:
s.bind((host, port))
except OSError as e:
raise RuntimeError(f"Failed to bind socket to {host}:{port}: {e}")
# If a disagg cluster config is provided and no port is specified (port == 0), pick a free port and retry with a new one on bind failure; otherwise, try once to bind to the specified port.
assert port > 0 or disagg_cluster_config is not None, "Port must be specified if disagg cluster config is not provided"
if port > 0:
port_retries = 1
else:
port_retries = 100
port = get_free_port()
while port_retries > 0:
try:
s.bind((host, port))
break
except OSError as e:
port_retries -= 1
if port_retries == 0:
raise RuntimeError(
f"Failed to bind socket to {host}:{port}: {e}")
else:
logger.warning(
f"Failed to bind socket to {host}:{port}: {e}, retrying {port_retries}..."
)
port = get_free_port()
if backend == 'pytorch':
llm_args.pop("build_config", None)

View File

@ -23,24 +23,6 @@ INACTIVE_TIMEOUT = 2
CHECK_STATUS_INTERVAL = 3
ROUTER_TYPES = ["round_robin", "load_balancing", "kv_cache_aware"]
USED_PORTS = set()
# get_free_port doesn't guarantee that consecutive calls will return different ports
# if no server is bound to the port immediately after the call
def get_free_unused_port():
    """Return a free port that this module has not handed out before.

    Each candidate comes from ``get_free_port()``. A candidate already
    recorded in the module-level ``USED_PORTS`` set is rejected and another
    one is requested; an accepted candidate is added to ``USED_PORTS``
    before it is returned.

    Raises:
        AssertionError: if ``get_free_port()`` returns a non-positive port.
        Exception: if no unrecorded port is found within 100 attempts.
    """
    global USED_PORTS
    max_attempts = 100
    attempts_made = 0
    while attempts_made < max_attempts:
        attempts_made += 1
        port = get_free_port()
        assert port > 0, f"get_free_port returned {port}"
        if port in USED_PORTS:
            # Already handed out earlier; ask for a different one.
            logger.info(f"Port {port} is already used, trying another one")
            continue
        USED_PORTS.add(port)
        return port
    raise Exception(
        f"Failed to find a free unused port after {max_attempts} attempts")
@pytest.fixture
@ -53,7 +35,7 @@ def model_name():
@pytest.fixture
def disagg_port():
return get_free_unused_port()
return get_free_port()
@pytest.fixture
@ -145,8 +127,6 @@ def _run_worker(model_name,
work_dir,
device=-1,
save_log=False):
if port == 0:
port = get_free_unused_port()
worker_config_path = os.path.join(work_dir, f"{role}_{port}_config.yaml")
with open(worker_config_path, "w+") as f:
yaml.dump(worker_config, f)
@ -187,6 +167,7 @@ def _run_worker(model_name,
port=port)
# Use 0 as the port and provide disagg_cluster_config to let the worker choose a free port
def run_ctx_worker(model_name, ctx_worker_config, work_dir, port=0, device=0):
    """Start a context worker by delegating to ``_run_worker`` with role "ctx".

    All arguments are forwarded unchanged. ``port`` defaults to 0 and
    ``device`` defaults to 0. Returns whatever ``_run_worker`` returns.
    """
    role = "ctx"
    worker = _run_worker(model_name, ctx_worker_config, role, port, work_dir,
                         device)
    return worker
@ -246,19 +227,38 @@ def periodic_check(timeout=300, interval=3):
return decorator
@periodic_check(timeout=300, interval=3)
async def wait_for_disagg_server_ready(port):
async def _wait_for_disagg_server_status(port,
ready=True,
min_ctx_workers=-1,
min_gen_workers=-1):
info_resp = requests.get(f"http://localhost:{port}/cluster_info")
logger.info(
f"Waiting for disagg server {port} to be ready: {info_resp.json()}")
if info_resp.status_code == 200:
info = info_resp.json()
return info["is_ready"]
else:
logger.info(f"Failed to get cluster info: {info_resp.status_code}")
if ready:
return info["is_ready"]
else:
return len(info["current_workers"]
["context_servers"]) >= min_ctx_workers and len(
info["current_workers"]
["generation_servers"]) >= min_gen_workers
return False
@periodic_check(timeout=300, interval=3)
async def wait_for_disagg_server_ready(port):
    """Check whether the disagg server on ``port`` reports itself ready.

    Delegates to ``_wait_for_disagg_server_status`` with ``ready=True``, so
    the result is the ``is_ready`` field of the server's ``/cluster_info``
    response, or ``False`` when that request does not return HTTP 200.

    NOTE(review): presumably ``periodic_check`` re-invokes this until it
    returns a truthy value or the timeout elapses -- confirm against the
    decorator's definition.
    """
    is_ready = await _wait_for_disagg_server_status(port, ready=True)
    return is_ready
@periodic_check(timeout=300, interval=3)
async def wait_for_disagg_server_status(port,
                                        min_ctx_workers=-1,
                                        min_gen_workers=-1):
    """Check whether the disagg server on ``port`` has enough workers.

    Delegates to ``_wait_for_disagg_server_status`` with ``ready=False``, so
    the result is true when the ``/cluster_info`` response lists at least
    ``min_ctx_workers`` context servers and at least ``min_gen_workers``
    generation servers. With the default of -1 a threshold is always met.

    NOTE(review): presumably ``periodic_check`` re-invokes this until it
    returns a truthy value or the timeout elapses -- confirm against the
    decorator's definition.
    """
    has_enough_workers = await _wait_for_disagg_server_status(
        port,
        ready=False,
        min_ctx_workers=min_ctx_workers,
        min_gen_workers=min_gen_workers)
    return has_enough_workers
@periodic_check(timeout=300, interval=3)
async def wait_for_worker_ready(port):
logger.info(f"Waiting for worker {port} to be ready")
@ -314,9 +314,6 @@ def terminate(*args, show_log_lines=30, release_port=True):
if arg.log_file:
arg.log_file.close()
arg.log_file = None
if release_port:
global USED_PORTS
USED_PORTS.discard(arg.port)
except Exception:
print(f"Failed to terminate process {arg.process.pid}")
else:
@ -399,8 +396,7 @@ async def test_minimal_instances(model_name, disagg_server_config,
gen_worker1 = run_gen_worker(model_name, worker_config, work_dir)
disagg_server = run_disagg_server(disagg_server_config, work_dir,
disagg_port)
await wait_for_worker_ready(ctx_worker1.port)
await wait_for_worker_ready(gen_worker1.port)
await wait_for_disagg_server_status(disagg_port, 1, 1)
verify_cluster_info(False, 1, 1, port=disagg_port)
# with only 1 ctx and 1 gen worker, the request should fail
with pytest.raises(Exception):
@ -470,7 +466,7 @@ async def test_worker_restart(model_name, disagg_server_config, worker_config,
work_dir,
port=0,
device=2)
await wait_for_worker_ready(gen_worker2.port)
await wait_for_disagg_server_status(disagg_port, 1, 1)
await asyncio.sleep(CHECK_STATUS_INTERVAL)
verify_cluster_info(True, 1, 1, port=disagg_port)
@ -492,7 +488,8 @@ async def test_worker_restart(model_name, disagg_server_config, worker_config,
work_dir,
port=0,
device=3)
await wait_for_worker_ready(ctx_worker2.port)
await wait_for_disagg_server_status(disagg_port, 1, 1)
await asyncio.sleep(CHECK_STATUS_INTERVAL)
verify_cluster_info(True, 1, 1, port=disagg_port)
response = request_completion(model_name, test_prompt, port=disagg_port)
@ -510,8 +507,7 @@ async def test_worker_restart(model_name, disagg_server_config, worker_config,
work_dir,
port=0,
device=1)
await wait_for_worker_ready(ctx_worker1.port)
await wait_for_worker_ready(gen_worker1.port)
await wait_for_disagg_server_status(disagg_port, 2, 2)
await asyncio.sleep(CHECK_STATUS_INTERVAL)
verify_cluster_info(True, 2, 2, port=disagg_port)