[None][fix] enable hmac in RPC (#9745)

Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com>
This commit is contained in:
Yan Chunwei 2025-12-07 08:24:46 +08:00 committed by GitHub
parent 2645a78f34
commit e4c707845f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 57 additions and 7 deletions

View File

@ -82,6 +82,8 @@ class RayExecutor(RpcExecutorMixin, GenerationExecutor):
is_llm_executor=is_llm_executor)
self.init_rpc_executor()
# Inject the generated HMAC key into worker_kwargs for workers
worker_kwargs['hmac_key'] = self.hmac_key
worker_kwargs['rpc_addr'] = self.rpc_addr
self.create_workers(RayGPUWorker, worker_kwargs)
self.setup_engine_remote()

View File

@ -168,6 +168,7 @@ class RayGPUWorker(RpcWorkerMixin, BaseWorker):
tokenizer: Optional[TokenizerBase] = None,
llm_args: Optional[BaseLlmArgs] = None,
rpc_addr: Optional[str] = None,
hmac_key: Optional[bytes] = None,
) -> None:
global logger
from tensorrt_llm.logger import logger
@ -191,7 +192,7 @@ class RayGPUWorker(RpcWorkerMixin, BaseWorker):
if rpc_addr is None:
raise RuntimeError(
"RPC mode enabled but no rpc_addr provided to RayGPUWorker")
self.init_rpc_worker(self.global_rank, rpc_addr)
self.init_rpc_worker(self.global_rank, rpc_addr, hmac_key)
self.start_rpc_server()
def setup_engine(self):

View File

@ -108,7 +108,8 @@ class RPCClient:
self._client_socket = ZeroMqQueue(address=(address, hmac_key),
is_server=False,
is_async=True,
use_hmac_encryption=False,
use_hmac_encryption=hmac_key
is not None,
socket_type=socket_type,
name="rpc_client")
self._pending_futures = {}

View File

@ -108,7 +108,8 @@ class RPCServer:
self._client_socket = ZeroMqQueue(address=(address, self._hmac_key),
is_server=True,
is_async=True,
use_hmac_encryption=False,
use_hmac_encryption=self._hmac_key
is not None,
socket_type=socket_type,
name="rpc_server")
logger.info(f"RPCServer is bound to {self._address}")

View File

@ -48,6 +48,8 @@ class GenerationExecutorRpcProxy(RpcExecutorMixin, GenerationExecutor):
self._create_mpi_session(model_world_size, mpi_session)
# Inject the generated HMAC key into worker_kwargs for workers
worker_kwargs['hmac_key'] = self.hmac_key
self.worker_kwargs = worker_kwargs
self.launch_workers()

View File

@ -1,6 +1,7 @@
import asyncio
import atexit
import json
import os
import threading
from typing import Callable, List, Optional
@ -29,7 +30,8 @@ class RpcExecutorMixin:
def init_rpc_executor(self):
self.rpc_addr = get_unique_ipc_addr()
self.rpc_client = RPCClient(self.rpc_addr)
self.hmac_key = os.urandom(32)
self.rpc_client = RPCClient(self.rpc_addr, hmac_key=self.hmac_key)
self._results = {}
self._shutdown_event = threading.Event()

View File

@ -155,7 +155,10 @@ class RpcWorker(RpcWorkerMixin, BaseWorker):
color="yellow")
# Step 2: Create the RPC service, it will expose all the APIs of the worker as remote call to the client
# Set num_workers to larger than 1 since there are some streaming tasks runs infinitely, such as await_responses_async.
rpc_server = RPCServer(worker, num_workers=worker.num_workers)
hmac_key = kwargs.get("hmac_key")
rpc_server = RPCServer(worker,
num_workers=worker.num_workers,
hmac_key=hmac_key)
rpc_server.bind(rpc_addr)
rpc_server.start()
logger_debug(f"[worker] RPC server {mpi_rank()} is started",

View File

@ -25,10 +25,11 @@ class RpcWorkerMixin:
# This can be overridden by setting num_workers in the inheriting class
NUM_WORKERS = 6
def init_rpc_worker(self, rank: int, rpc_addr: Optional[str]):
def init_rpc_worker(self, rank: int, rpc_addr: Optional[str], hmac_key: Optional[bytes] = None):
if rpc_addr is None:
raise RuntimeError("RPC mode enabled but no rpc_addr provided to worker")
self.hmac_key = hmac_key
self.rank = rank
self.shutdown_event = Event()
self._response_queue = Queue()
@ -41,7 +42,7 @@ class RpcWorkerMixin:
if self.rank == 0:
# Use num_workers if set on the instance, otherwise use class default
num_workers = getattr(self, "num_workers", RpcWorkerMixin.NUM_WORKERS)
self.rpc_server = RPCServer(self, num_workers=num_workers)
self.rpc_server = RPCServer(self, num_workers=num_workers, hmac_key=self.hmac_key)
self.rpc_server.bind(self.rpc_addr)
self.rpc_server.start()

View File

@ -95,6 +95,43 @@ class TestRpcProxy:
assert similar(tokenizer.decode(result.outputs[0].token_ids),
'E F G H I J K L')
def test_hmac_key_generation(self):
"""Test that HMAC key is automatically generated and properly propagated."""
tokenizer = TransformersTokenizer.from_pretrained(model_path)
prompt = "A B C D"
prompt_token_ids = tokenizer.encode(prompt)
max_tokens = 8
with self.create_proxy(tp_size=1) as proxy:
assert proxy.hmac_key is not None, "HMAC key should be generated"
assert len(
proxy.hmac_key
) == 32, f"HMAC key should be 32 bytes, got {len(proxy.hmac_key)}"
# Verify key is properly stored in worker_kwargs
assert 'hmac_key' in proxy.worker_kwargs, "HMAC key should be in worker_kwargs"
assert proxy.worker_kwargs[
'hmac_key'] is not None, "HMAC key in worker_kwargs should not be None"
# Verify both references point to the same key object
assert proxy.hmac_key is proxy.worker_kwargs['hmac_key'], \
"HMAC key should be the same object in both locations"
logger_debug(
f"[Test] HMAC key verified: length={len(proxy.hmac_key)} bytes",
color="green")
# Verify RPC communication works with the generated key
sampling_params = SamplingParams(max_tokens=max_tokens)
result = proxy.generate(prompt_token_ids, sampling_params)
assert similar(
tokenizer.decode(result.outputs[0].token_ids), 'E F G H I J K L'
), "Generation should work with auto-generated HMAC key"
logger_debug(
f"[Test] HMAC key test passed: RPC communication successful",
color="green")
if __name__ == "__main__":
TestRpcProxy().test_tp1(20)