diff --git a/tensorrt_llm/commands/serve.py b/tensorrt_llm/commands/serve.py
index ff189e3be9..9eb271551a 100644
--- a/tensorrt_llm/commands/serve.py
+++ b/tensorrt_llm/commands/serve.py
@@ -18,7 +18,7 @@ from torch.cuda import device_count
 from tensorrt_llm import LLM as PyTorchLLM
 from tensorrt_llm import MultimodalEncoder
 from tensorrt_llm._tensorrt_engine import LLM
-from tensorrt_llm._utils import get_free_port, mpi_rank
+from tensorrt_llm._utils import mpi_rank
 from tensorrt_llm.executor.utils import LlmLauncherEnvs
 from tensorrt_llm.inputs.multimodal import MultimodalServerConfig
 from tensorrt_llm.llmapi import (BuildConfig, CapacitySchedulerPolicy,
@@ -189,25 +189,12 @@ def launch_server(
     with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
         # If disagg cluster config is provided and port is not specified, try to find a free port, otherwise try to bind to the specified port
         assert port > 0 or disagg_cluster_config is not None, "Port must be specified if disagg cluster config is not provided"
-        if port > 0:
-            port_retries = 1
-        else:
-            port_retries = 100
-            port = get_free_port()
-        while port_retries > 0:
-            try:
-                s.bind((host, port))
-                break
-            except OSError as e:
-                port_retries -= 1
-                if port_retries == 0:
-                    raise RuntimeError(
-                        f"Failed to bind socket to {host}:{port}: {e}")
-                else:
-                    logger.warning(
-                        f"Failed to bind socket to {host}:{port}: {e}, retrying {port_retries}..."
-                    )
-                    port = get_free_port()
+        try:
+            s.bind((host, port))
+            if port == 0:
+                port = s.getsockname()[1]
+        except OSError as e:
+            raise RuntimeError(f"Failed to bind socket to {host}:{port}: {e}")
 
     if backend == 'pytorch':
         llm_args.pop("build_config", None)
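
For context on the change above: binding to port 0 and reading the assigned port back with `getsockname()` lets the kernel pick a free ephemeral port atomically, which removes the race inherent in the old `get_free_port()` retry loop (a port observed free could be claimed by another process before the subsequent `bind`). A minimal standalone sketch of that pattern follows; the `127.0.0.1` host is assumed purely for illustration:

```python
import socket

host = "127.0.0.1"  # assumed host, for illustration only
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
    # Port 0 asks the OS to assign a free ephemeral port atomically,
    # so there is no window in which another process can steal it.
    s.bind((host, 0))
    port = s.getsockname()[1]  # the port the kernel actually assigned
    print(f"bound to {host}:{port}")
```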