mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
[None][fix] enable hmac in RPC (#9745)
Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com>
This commit is contained in:
parent
2645a78f34
commit
e4c707845f
@ -82,6 +82,8 @@ class RayExecutor(RpcExecutorMixin, GenerationExecutor):
|
||||
is_llm_executor=is_llm_executor)
|
||||
|
||||
self.init_rpc_executor()
|
||||
# Inject the generated HMAC key into worker_kwargs for workers
|
||||
worker_kwargs['hmac_key'] = self.hmac_key
|
||||
worker_kwargs['rpc_addr'] = self.rpc_addr
|
||||
self.create_workers(RayGPUWorker, worker_kwargs)
|
||||
self.setup_engine_remote()
|
||||
|
||||
@ -168,6 +168,7 @@ class RayGPUWorker(RpcWorkerMixin, BaseWorker):
|
||||
tokenizer: Optional[TokenizerBase] = None,
|
||||
llm_args: Optional[BaseLlmArgs] = None,
|
||||
rpc_addr: Optional[str] = None,
|
||||
hmac_key: Optional[bytes] = None,
|
||||
) -> None:
|
||||
global logger
|
||||
from tensorrt_llm.logger import logger
|
||||
@ -191,7 +192,7 @@ class RayGPUWorker(RpcWorkerMixin, BaseWorker):
|
||||
if rpc_addr is None:
|
||||
raise RuntimeError(
|
||||
"RPC mode enabled but no rpc_addr provided to RayGPUWorker")
|
||||
self.init_rpc_worker(self.global_rank, rpc_addr)
|
||||
self.init_rpc_worker(self.global_rank, rpc_addr, hmac_key)
|
||||
self.start_rpc_server()
|
||||
|
||||
def setup_engine(self):
|
||||
|
||||
@ -108,7 +108,8 @@ class RPCClient:
|
||||
self._client_socket = ZeroMqQueue(address=(address, hmac_key),
|
||||
is_server=False,
|
||||
is_async=True,
|
||||
use_hmac_encryption=False,
|
||||
use_hmac_encryption=hmac_key
|
||||
is not None,
|
||||
socket_type=socket_type,
|
||||
name="rpc_client")
|
||||
self._pending_futures = {}
|
||||
|
||||
@ -108,7 +108,8 @@ class RPCServer:
|
||||
self._client_socket = ZeroMqQueue(address=(address, self._hmac_key),
|
||||
is_server=True,
|
||||
is_async=True,
|
||||
use_hmac_encryption=False,
|
||||
use_hmac_encryption=self._hmac_key
|
||||
is not None,
|
||||
socket_type=socket_type,
|
||||
name="rpc_server")
|
||||
logger.info(f"RPCServer is bound to {self._address}")
|
||||
|
||||
@ -48,6 +48,8 @@ class GenerationExecutorRpcProxy(RpcExecutorMixin, GenerationExecutor):
|
||||
|
||||
self._create_mpi_session(model_world_size, mpi_session)
|
||||
|
||||
# Inject the generated HMAC key into worker_kwargs for workers
|
||||
worker_kwargs['hmac_key'] = self.hmac_key
|
||||
self.worker_kwargs = worker_kwargs
|
||||
|
||||
self.launch_workers()
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
import atexit
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
@ -29,7 +30,8 @@ class RpcExecutorMixin:
|
||||
|
||||
def init_rpc_executor(self):
|
||||
self.rpc_addr = get_unique_ipc_addr()
|
||||
self.rpc_client = RPCClient(self.rpc_addr)
|
||||
self.hmac_key = os.urandom(32)
|
||||
self.rpc_client = RPCClient(self.rpc_addr, hmac_key=self.hmac_key)
|
||||
|
||||
self._results = {}
|
||||
self._shutdown_event = threading.Event()
|
||||
|
||||
@ -155,7 +155,10 @@ class RpcWorker(RpcWorkerMixin, BaseWorker):
|
||||
color="yellow")
|
||||
# Step 2: Create the RPC service, which exposes all of the worker's APIs as remote calls to the client.
|
||||
# Set num_workers to more than 1 because some streaming tasks, such as await_responses_async, run indefinitely.
|
||||
rpc_server = RPCServer(worker, num_workers=worker.num_workers)
|
||||
hmac_key = kwargs.get("hmac_key")
|
||||
rpc_server = RPCServer(worker,
|
||||
num_workers=worker.num_workers,
|
||||
hmac_key=hmac_key)
|
||||
rpc_server.bind(rpc_addr)
|
||||
rpc_server.start()
|
||||
logger_debug(f"[worker] RPC server {mpi_rank()} is started",
|
||||
|
||||
@ -25,10 +25,11 @@ class RpcWorkerMixin:
|
||||
# This can be overridden by setting num_workers in the inheriting class
|
||||
NUM_WORKERS = 6
|
||||
|
||||
def init_rpc_worker(self, rank: int, rpc_addr: Optional[str]):
|
||||
def init_rpc_worker(self, rank: int, rpc_addr: Optional[str], hmac_key: Optional[bytes] = None):
|
||||
if rpc_addr is None:
|
||||
raise RuntimeError("RPC mode enabled but no rpc_addr provided to worker")
|
||||
|
||||
self.hmac_key = hmac_key
|
||||
self.rank = rank
|
||||
self.shutdown_event = Event()
|
||||
self._response_queue = Queue()
|
||||
@ -41,7 +42,7 @@ class RpcWorkerMixin:
|
||||
if self.rank == 0:
|
||||
# Use num_workers if set on the instance, otherwise use class default
|
||||
num_workers = getattr(self, "num_workers", RpcWorkerMixin.NUM_WORKERS)
|
||||
self.rpc_server = RPCServer(self, num_workers=num_workers)
|
||||
self.rpc_server = RPCServer(self, num_workers=num_workers, hmac_key=self.hmac_key)
|
||||
self.rpc_server.bind(self.rpc_addr)
|
||||
self.rpc_server.start()
|
||||
|
||||
|
||||
@ -95,6 +95,43 @@ class TestRpcProxy:
|
||||
assert similar(tokenizer.decode(result.outputs[0].token_ids),
|
||||
'E F G H I J K L')
|
||||
|
||||
def test_hmac_key_generation(self):
    """Test that an HMAC key is generated automatically and propagated.

    The test checks three things on a proxy created with ``tp_size=1``:

    1. ``proxy.hmac_key`` exists and is 32 bytes long.
    2. ``proxy.worker_kwargs['hmac_key']`` is the very same object.
    3. A generation request through the proxy still returns the expected
       continuation, i.e. RPC communication works with the generated key.
    """
    tokenizer = TransformersTokenizer.from_pretrained(model_path)
    prompt = "A B C D"
    prompt_token_ids = tokenizer.encode(prompt)
    max_tokens = 8

    with self.create_proxy(tp_size=1) as proxy:
        assert proxy.hmac_key is not None, "HMAC key should be generated"
        assert len(
            proxy.hmac_key
        ) == 32, f"HMAC key should be 32 bytes, got {len(proxy.hmac_key)}"

        # Verify key is properly stored in worker_kwargs
        assert 'hmac_key' in proxy.worker_kwargs, "HMAC key should be in worker_kwargs"
        assert proxy.worker_kwargs[
            'hmac_key'] is not None, "HMAC key in worker_kwargs should not be None"

        # Verify both references point to the same key object
        assert proxy.hmac_key is proxy.worker_kwargs['hmac_key'], \
            "HMAC key should be the same object in both locations"

        logger_debug(
            f"[Test] HMAC key verified: length={len(proxy.hmac_key)} bytes",
            color="green")

        # Verify RPC communication works with the generated key
        sampling_params = SamplingParams(max_tokens=max_tokens)
        result = proxy.generate(prompt_token_ids, sampling_params)
        assert similar(
            tokenizer.decode(result.outputs[0].token_ids), 'E F G H I J K L'
        ), "Generation should work with auto-generated HMAC key"

        # Plain string: the message has no placeholders, so no f-prefix.
        logger_debug(
            "[Test] HMAC key test passed: RPC communication successful",
            color="green")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
TestRpcProxy().test_tp1(20)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user