diff --git a/tensorrt_llm/llmapi/rlhf_utils.py b/tensorrt_llm/llmapi/rlhf_utils.py
index ce6eaa5b4f..b49eaa70d5 100644
--- a/tensorrt_llm/llmapi/rlhf_utils.py
+++ b/tensorrt_llm/llmapi/rlhf_utils.py
@@ -1,4 +1,5 @@
 import base64
+import io
 import pickle  # nosec B403
 from typing import Optional
 
@@ -10,6 +11,44 @@ from tensorrt_llm._torch.utils import get_device_uuid
 from tensorrt_llm.logger import logger
 
 
+class RestrictedUnpickler(pickle.Unpickler):
+    """Restricted unpickler that only allows safe types.
+
+    This prevents arbitrary code execution by restricting what can be deserialized.
+    Only allows basic builtins (list, tuple, dict, str, int, float, bool, bytes,
+    NoneType, type) and torch-related types needed for tensor reconstruction.
+    """
+
+    # Allowed modules and their classes
+    ALLOWED_TYPES = {
+        "builtins": {
+            "list",
+            "tuple",
+            "str",
+            "int",
+            "float",
+            "bool",
+            "bytes",
+            "dict",
+            "NoneType",
+            "type",
+        },
+    }
+
+    def find_class(self, module, name):
+        """Override to restrict which classes can be unpickled."""
+        # Allow anything under the torch namespace; everything else must be an explicitly allowed builtin
+        is_torch_module = module.startswith("torch")
+        if is_torch_module or (module in self.ALLOWED_TYPES and name in self.ALLOWED_TYPES[module]):
+            return super().find_class(module, name)
+
+        raise pickle.UnpicklingError(
+            f"Global '{module}.{name}' is forbidden for security reasons. "
+            "Only basic builtin types (list, tuple, str, int, etc.) and "
+            "torch tensor types are allowed."
+        )
+
+
 class WorkerExtension:
     """Worker extension class for extending TensorRT-LLM Ray workers
     with custom functionality.
@@ -61,9 +100,16 @@ class WorkerExtension:
             serialized_handles = ipc_handles[device_uuid]
 
             if isinstance(serialized_handles, str):
-                # Data is base64-encoded pickled bytes - deserialize it
+                # Data is base64-encoded pickled bytes; deserialize it with the restricted unpickler
                 logger.info("Deserializing base64-encoded weight handles")
-                all_handles = pickle.loads(base64.b64decode(serialized_handles))  # nosec B301
+                decoded_data = base64.b64decode(serialized_handles)
+                all_handles = RestrictedUnpickler(io.BytesIO(decoded_data)).load()
+
+                # Verify the result is a list as expected
+                if not isinstance(all_handles, list):
+                    raise ValueError(
+                        f"Deserialized data must be a list, got {type(all_handles).__name__} instead"
+                    )
             else:
                 # Data is already in the correct format (backward compatibility)
                 all_handles = serialized_handles
diff --git a/tests/unittest/_torch/ray_orchestrator/single_gpu/test_llm_update_weights.py b/tests/unittest/_torch/ray_orchestrator/single_gpu/test_llm_update_weights.py
index 96e8822612..1fcb865f97 100644
--- a/tests/unittest/_torch/ray_orchestrator/single_gpu/test_llm_update_weights.py
+++ b/tests/unittest/_torch/ray_orchestrator/single_gpu/test_llm_update_weights.py
@@ -1,3 +1,6 @@
+import base64
+import io
+import pickle
 from typing import Callable, List, Optional
 
 import pytest
@@ -8,6 +11,7 @@ from utils.llm_data import llm_models_root
 
 from tensorrt_llm import LLM
 from tensorrt_llm.llmapi import KvCacheConfig, SamplingParams
+from tensorrt_llm.llmapi.rlhf_utils import RestrictedUnpickler
 
 
 class HFModel:
@@ -71,6 +75,40 @@
 
         return ret
 
+    def get_weight_ipc_handles_serialized(
+        self,
+        cuda_device: Optional[List[int]] = None,
+        weight_filter: Optional[Callable[[str], bool]] = None,
+    ):
+        """
+        Get base64-encoded serialized IPC handles for model weights.
+
+        Args:
+            cuda_device: List of CUDA device indices to get weights from
+            weight_filter: Optional function that takes a weight name and returns True if the weight should be included
+
+        Returns:
+            ret: Dictionary mapping device UUIDs to base64-encoded pickled handles
+        """
+        ret = {}
+        device_list = list(range(torch.cuda.device_count())) if cuda_device is None else cuda_device
+
+        for device in device_list:
+            all_handles = []
+            for item in self.all_weights[device]:
+                name, p = item
+                # Apply filter if provided
+                if weight_filter is not None and not weight_filter(name):
+                    continue
+                handle = reduce_tensor(p)
+                all_handles.append((name, handle))
+
+            # Serialize with base64-encoded pickle
+            serialized = base64.b64encode(pickle.dumps(all_handles)).decode("utf-8")
+            ret[self.device_uuid[device]] = serialized
+
+        return ret
+
     def generate_batch_incremental(
         self, original_prompts: List[str], generated_token_ids_list: List[List[int]]
     ):
@@ -153,6 +191,57 @@ def run_generate(llm, hf_model, prompts, sampling_params):
     return llm_logits, ref_logits
 
 
+@pytest.mark.parametrize(
+    "model_dir",
+    ["Qwen2.5-0.5B-Instruct"],
+)
+def test_llm_update_weights_with_serialized_handles(model_dir):
+    """Test LLM update_weights with base64-encoded serialized handles (RestrictedUnpickler)."""
+    model_dir = str(llm_models_root() / model_dir)
+    kv_cache_config = KvCacheConfig(enable_block_reuse=True, free_gpu_memory_fraction=0.1)
+
+    hf_model = HFModel(model_dir)
+
+    llm = LLM(
+        model=model_dir,
+        ray_worker_extension_cls="tensorrt_llm.llmapi.rlhf_utils.WorkerExtension",
+        tensor_parallel_size=1,
+        load_format="dummy",
+        pipeline_parallel_size=1,
+        kv_cache_config=kv_cache_config,
+    )
+
+    prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+    ]
+
+    sampling_params = SamplingParams(temperature=0, return_generation_logits=True)
+
+    # Use the serialized format (base64-encoded pickle)
+    ipc_handles_serialized = hf_model.get_weight_ipc_handles_serialized([0])
+
+    # Verify the format is correct (should be base64-encoded strings)
+    for device_uuid, serialized_data in ipc_handles_serialized.items():
+        assert isinstance(serialized_data, str), "Should be a base64-encoded string"
+        # Verify it can be decoded
+        decoded = base64.b64decode(serialized_data)
+        # Verify it can be deserialized with RestrictedUnpickler
+        deserialized = RestrictedUnpickler(io.BytesIO(decoded)).load()
+        assert isinstance(deserialized, list), "Should deserialize to a list"
+
+    # Update weights using the serialized format
+    llm._collective_rpc("update_weights", (ipc_handles_serialized,))
+    # Finalize the weight update
+    llm._collective_rpc("update_weights", (None,))
+
+    # Verify generation works with updated weights
+    llm_logits, ref_logits = run_generate(llm, hf_model, prompts, sampling_params)
+    compare_logits(llm_logits, ref_logits)
+
+    print("✓ LLM update_weights with serialized handles (RestrictedUnpickler) works!")
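+
+
+# A hypothetical companion test (illustrative sketch; the test name is ours):
+# exercise the security property of RestrictedUnpickler directly. pickle stores
+# os.system as a bare global reference, so loading it must go through
+# find_class, which should refuse anything outside the torch/builtins
+# allowlist, while payloads built purely from allowed basic types round-trip.
+def test_restricted_unpickler_rejects_forbidden_globals():
+    import os
+
+    # A payload referencing a forbidden global must raise instead of importing it.
+    malicious = pickle.dumps(os.system)
+    with pytest.raises(pickle.UnpicklingError):
+        RestrictedUnpickler(io.BytesIO(malicious)).load()
+
+    # Allowed basic types are unaffected by the restriction.
+    safe = [("weight", (1, 2.0, True, None))]
+    assert RestrictedUnpickler(io.BytesIO(pickle.dumps(safe))).load() == safe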
+
+
 @pytest.mark.parametrize(
     "model_dir",
     ["Qwen2.5-0.5B-Instruct", "Qwen3/Qwen3-8B", "llama-models-v2/TinyLlama-1.1B-Chat-v1.0"],