diff --git a/tensorrt_llm/executor/utils.py b/tensorrt_llm/executor/utils.py index e52ea481fb..9801ee1cd2 100644 --- a/tensorrt_llm/executor/utils.py +++ b/tensorrt_llm/executor/utils.py @@ -55,9 +55,9 @@ def create_mpi_comm_session( logger_debug( f"Using RemoteMpiPoolSessionClient to bind to external MPI processes at {get_spawn_proxy_process_ipc_addr_env()}\n", "yellow") - get_spawn_proxy_process_ipc_hmac_key_env() + hmac_key = get_spawn_proxy_process_ipc_hmac_key_env() return RemoteMpiCommSessionClient( - addr=get_spawn_proxy_process_ipc_addr_env()) + addr=get_spawn_proxy_process_ipc_addr_env(), hmac_key=hmac_key) else: logger_debug( f"Using MpiCommSession to bind to external MPI processes\n", diff --git a/tensorrt_llm/llmapi/mgmn_leader_node.py b/tensorrt_llm/llmapi/mgmn_leader_node.py index 85f8561ebe..2b1d11b0cc 100644 --- a/tensorrt_llm/llmapi/mgmn_leader_node.py +++ b/tensorrt_llm/llmapi/mgmn_leader_node.py @@ -9,7 +9,9 @@ import zmq from tensorrt_llm._utils import global_mpi_rank, mpi_world_size from tensorrt_llm.executor.ipc import ZeroMqQueue -from tensorrt_llm.executor.utils import get_spawn_proxy_process_ipc_addr_env +from tensorrt_llm.executor.utils import ( + get_spawn_proxy_process_ipc_addr_env, + get_spawn_proxy_process_ipc_hmac_key_env) from tensorrt_llm.llmapi.mpi_session import RemoteMpiCommSessionServer from tensorrt_llm.llmapi.utils import logger_debug @@ -23,6 +25,7 @@ def launch_server_main(sub_comm=None): comm=sub_comm, n_workers=num_ranks, addr=get_spawn_proxy_process_ipc_addr_env(), + hmac_key=get_spawn_proxy_process_ipc_hmac_key_env(), is_comm=True) logger_debug( f"MPI Comm Server started at {get_spawn_proxy_process_ipc_addr_env()}") @@ -32,8 +35,9 @@ def launch_server_main(sub_comm=None): def stop_server_main(): - queue = ZeroMqQueue((get_spawn_proxy_process_ipc_addr_env(), None), - use_hmac_encryption=False, + hmac_key = get_spawn_proxy_process_ipc_hmac_key_env() + queue = ZeroMqQueue((get_spawn_proxy_process_ipc_addr_env(), hmac_key), + use_hmac_encryption=bool(hmac_key), is_server=False, socket_type=zmq.PAIR) diff --git a/tests/unittest/llmapi/_run_mpi_comm_task.py b/tests/unittest/llmapi/_run_mpi_comm_task.py index b60b7a1efd..94ca4d3865 100644 --- a/tests/unittest/llmapi/_run_mpi_comm_task.py +++ b/tests/unittest/llmapi/_run_mpi_comm_task.py @@ -3,6 +3,7 @@ from typing import Literal import click +from tensorrt_llm.executor.utils import get_spawn_proxy_process_ipc_hmac_key_env from tensorrt_llm.llmapi.mpi_session import RemoteMpiCommSessionClient from tensorrt_llm.llmapi.utils import print_colored @@ -15,8 +16,9 @@ def main(task_type: Literal["submit", "submit_sync"]): tasks = [0] assert os.environ[ 'TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR'] is not None, "TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR is not set" + hmac_key = get_spawn_proxy_process_ipc_hmac_key_env() client = RemoteMpiCommSessionClient( - os.environ['TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR']) + os.environ['TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR'], hmac_key=hmac_key) for task in tasks: if task_type == "submit": client.submit(print_colored, f"{task}\n", "green") diff --git a/tests/unittest/llmapi/_run_multi_mpi_comm_tasks.py b/tests/unittest/llmapi/_run_multi_mpi_comm_tasks.py index 5b50df94f2..440d07149c 100644 --- a/tests/unittest/llmapi/_run_multi_mpi_comm_tasks.py +++ b/tests/unittest/llmapi/_run_multi_mpi_comm_tasks.py @@ -3,7 +3,8 @@ from typing import Literal import click -from tensorrt_llm.executor.utils import LlmLauncherEnvs +from tensorrt_llm.executor.utils import ( + LlmLauncherEnvs, get_spawn_proxy_process_ipc_hmac_key_env) from tensorrt_llm.llmapi.mpi_session import RemoteMpiCommSessionClient from tensorrt_llm.llmapi.utils import print_colored @@ -13,8 +14,10 @@ def run_task(task_type: Literal["submit", "submit_sync"]): assert os.environ[ LlmLauncherEnvs. TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR] is not None, "TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR is not set" + hmac_key = get_spawn_proxy_process_ipc_hmac_key_env() client = RemoteMpiCommSessionClient( - os.environ[LlmLauncherEnvs.TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR]) + os.environ[LlmLauncherEnvs.TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR], + hmac_key=hmac_key) for task in tasks: if task_type == "submit":