diff --git a/tensorrt_llm/llmapi/rlhf_utils.py b/tensorrt_llm/llmapi/rlhf_utils.py
index ce6eaa5b4f..52219f9985 100644
--- a/tensorrt_llm/llmapi/rlhf_utils.py
+++ b/tensorrt_llm/llmapi/rlhf_utils.py
@@ -1,9 +1,9 @@
 import base64
-import pickle  # nosec B403
 from typing import Optional
 
 import torch
 
+from tensorrt_llm import serialization
 from tensorrt_llm._ray_utils import control_action_decorator
 from tensorrt_llm._torch.modules.fused_moe.moe_load_balancer import MoeLoadBalancer
 from tensorrt_llm._torch.utils import get_device_uuid
@@ -62,8 +62,35 @@ class WorkerExtension:
             serialized_handles = ipc_handles[device_uuid]
             if isinstance(serialized_handles, str):
                 # Data is base64-encoded pickled bytes - deserialize it
+                # using the restricted unpickler from tensorrt_llm.serialization
                 logger.info("Deserializing base64-encoded weight handles")
-                all_handles = pickle.loads(base64.b64decode(serialized_handles))  # nosec B301
+                decoded_data = base64.b64decode(serialized_handles)
+                # Allow basic builtins and all torch modules
+                approved_imports = {
+                    "builtins": [
+                        "list",
+                        "tuple",
+                        "str",
+                        "int",
+                        "float",
+                        "bool",
+                        "bytes",
+                        "dict",
+                        "NoneType",
+                        "type",
+                    ],
+                }
+                all_handles = serialization.loads(
+                    decoded_data,
+                    approved_imports=approved_imports,
+                    approved_module_patterns=[r"^torch.*"],
+                )
+
+                # Verify the result is a list, as expected
+                if not isinstance(all_handles, list):
+                    raise ValueError(
+                        f"Deserialized data must be a list, got {type(all_handles).__name__} instead"
+                    )
             else:
                 # Data is already in the correct format (backward compatibility)
                 all_handles = serialized_handles
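For reference, a minimal sketch of the round trip this code path now expects: the producer base64-encodes pickled handles, and the worker feeds the decoded bytes through the restricted loader. This is illustrative rather than part of the patch; the payload below is a stand-in, since real entries are (name, handle) tuples produced by reduce_tensor.

import base64
import pickle  # nosec B403 - used here only to build a trusted demo payload

from tensorrt_llm import serialization

# Stand-in for real (name, handle) pairs.
payload = base64.b64encode(pickle.dumps([("layer.weight", None)])).decode("utf-8")

handles = serialization.loads(
    base64.b64decode(payload),
    approved_imports={"builtins": ["list", "tuple", "str", "NoneType"]},
    approved_module_patterns=[r"^torch.*"],
)
assert isinstance(handles, list)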
diff --git a/tensorrt_llm/serialization.py b/tensorrt_llm/serialization.py
index 4045df3617..ac303fec9f 100644
--- a/tensorrt_llm/serialization.py
+++ b/tensorrt_llm/serialization.py
@@ -2,6 +2,7 @@ import io
 # pickle is not secure, but but this whole file is a wrapper to make it
 # possible to mitigate the primary risk of code injection via pickle.
 import pickle  # nosec B403
+import re
 from functools import partial
 
 # This is an example class (white list) to showcase how to guard serialization with approved classes.
@@ -126,19 +127,31 @@ def register_approved_class(obj):
 
 
 class Unpickler(pickle.Unpickler):
-    def __init__(self, *args, approved_imports={}, **kwargs):
+    def __init__(self,
+                 *args,
+                 approved_imports={},
+                 approved_module_patterns=None,
+                 **kwargs):
         super().__init__(*args, **kwargs)
         self.approved_imports = approved_imports
+        self.approved_module_patterns = approved_module_patterns or []
 
     # only import approved classes, this is the security boundary.
     def find_class(self, module, name):
-        if name not in self.approved_imports.get(module, []):
-            # If this is triggered when it shouldn't be, then the module
-            # and class should be added to the approved_imports. If the class
-            # is being used as part of a routine scenario, then it should be added
-            # to the appropriate base classes above.
-            raise ValueError(f"Import {module} | {name} is not allowed")
-        return super().find_class(module, name)
+        # Check exact match in approved_imports
+        if name in self.approved_imports.get(module, []):
+            return super().find_class(module, name)
+
+        # Check regex pattern match in approved_module_patterns
+        for pattern in self.approved_module_patterns:
+            if re.match(pattern, module):
+                return super().find_class(module, name)
+
+        # If this is triggered when it shouldn't be, then the module
+        # and class should be added to the approved_imports. If the class
+        # is being used as part of a routine scenario, then it should be added
+        # to the appropriate base classes above.
+        raise ValueError(f"Import {module} | {name} is not allowed")
 
 
 # these are taken from the pickle module to allow for this to be a drop in replacement
@@ -156,13 +169,15 @@ def load(file,
          encoding="ASCII",
          errors="strict",
          buffers=None,
-         approved_imports={}):
+         approved_imports={},
+         approved_module_patterns=None):
     return Unpickler(file,
                      fix_imports=fix_imports,
                      buffers=buffers,
                      encoding=encoding,
                      errors=errors,
-                     approved_imports=approved_imports).load()
+                     approved_imports=approved_imports,
+                     approved_module_patterns=approved_module_patterns).load()
 
 
 def loads(s,
@@ -172,7 +187,8 @@
           encoding="ASCII",
           errors="strict",
           buffers=None,
-          approved_imports={}):
+          approved_imports={},
+          approved_module_patterns=None):
     if isinstance(s, str):
         raise TypeError("Can't load pickle from unicode string")
     file = io.BytesIO(s)
@@ -181,4 +197,5 @@
                      buffers=buffers,
                      encoding=encoding,
                      errors=errors,
-                     approved_imports=approved_imports).load()
+                     approved_imports=approved_imports,
+                     approved_module_patterns=approved_module_patterns).load()
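A quick sketch of how the extended security boundary behaves (again illustrative, not part of the patch): an exact entry in approved_imports admits a specific module/name pair, a regex in approved_module_patterns admits every name from any matching module, and anything else raises ValueError.

import pickle  # nosec B403 - used here only to build the demo payload

from tensorrt_llm import serialization

# Unpickling an exception instance forces find_class("pickle", "PickleError").
data = pickle.dumps(pickle.PickleError("demo"))

try:
    serialization.loads(data)  # nothing approved: rejected at the boundary
except ValueError as err:
    print(err)  # Import pickle | PickleError is not allowed

serialization.loads(data, approved_imports={"pickle": ["PickleError"]})  # exact match
serialization.loads(data, approved_module_patterns=[r"^pickle$"])  # pattern match

Because a pattern admits every name in matching modules, and re.match means r"^torch.*" also matches modules such as torchvision, callers should keep patterns as narrow as their payloads allow.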
diff --git a/tests/unittest/_torch/ray_orchestrator/single_gpu/test_llm_update_weights.py b/tests/unittest/_torch/ray_orchestrator/single_gpu/test_llm_update_weights.py
index 96e8822612..9914913c2f 100644
--- a/tests/unittest/_torch/ray_orchestrator/single_gpu/test_llm_update_weights.py
+++ b/tests/unittest/_torch/ray_orchestrator/single_gpu/test_llm_update_weights.py
@@ -1,3 +1,5 @@
+import base64
+import pickle
 from typing import Callable, List, Optional
 
 import pytest
@@ -71,6 +73,40 @@ class HFModel:
 
         return ret
 
+    def get_weight_ipc_handles_serialized(
+        self,
+        cuda_device: Optional[List[int]] = None,
+        weight_filter: Optional[Callable[[str], bool]] = None,
+    ):
+        """
+        Get base64-encoded serialized IPC handles for model weights.
+
+        Args:
+            cuda_device: List of CUDA device indices to get weights from
+            weight_filter: Optional function that takes a weight name and returns True if the weight should be included
+
+        Returns:
+            ret: Dictionary mapping device UUIDs to base64-encoded pickled handles
+        """
+        ret = {}
+        device_list = list(range(torch.cuda.device_count())) if cuda_device is None else cuda_device
+
+        for device in device_list:
+            all_handles = []
+            for item in self.all_weights[device]:
+                name, p = item
+                # Apply filter if provided
+                if weight_filter is not None and not weight_filter(name):
+                    continue
+                handle = reduce_tensor(p)
+                all_handles.append((name, handle))
+
+            # Serialize with base64-encoded pickle
+            serialized = base64.b64encode(pickle.dumps(all_handles)).decode("utf-8")
+            ret[self.device_uuid[device]] = serialized
+
+        return ret
+
     def generate_batch_incremental(
         self, original_prompts: List[str], generated_token_ids_list: List[List[int]]
     ):
@@ -153,11 +189,13 @@ def run_generate(llm, hf_model, prompts, sampling_params):
     return llm_logits, ref_logits
 
 
+@pytest.mark.parametrize("use_serialized_handles", [True, False])
 @pytest.mark.parametrize(
     "model_dir",
     ["Qwen2.5-0.5B-Instruct", "Qwen3/Qwen3-8B", "llama-models-v2/TinyLlama-1.1B-Chat-v1.0"],
 )
-def test_llm_update_weights(model_dir):
+def test_llm_update_weights(model_dir, use_serialized_handles):
+    """Test LLM update_weights with both serialized and direct IPC handle formats."""
     model_dir = str(llm_models_root() / model_dir)
 
     kv_cache_config = KvCacheConfig(enable_block_reuse=True, free_gpu_memory_fraction=0.1)
@@ -182,7 +220,11 @@ def test_llm_update_weights(model_dir):
 
     sampling_params = SamplingParams(temperature=0, return_generation_logits=True)
 
-    ipc_handles = hf_model.get_weight_ipc_handles([0])
+    # Get IPC handles in either serialized or direct format
+    if use_serialized_handles:
+        ipc_handles = hf_model.get_weight_ipc_handles_serialized([0])
+    else:
+        ipc_handles = hf_model.get_weight_ipc_handles([0])
     llm._collective_rpc("update_weights", (ipc_handles,))
 
     # Finalize the update weights
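For context, the payload the serialized path ships per tensor looks roughly like this (illustrative only; it mirrors the test helper above, uses reduce_tensor from torch.multiprocessing, and requires a CUDA device):

import base64
import pickle  # nosec B403
import torch
from torch.multiprocessing.reductions import reduce_tensor

param = torch.ones(4, device="cuda")
# reduce_tensor returns a (rebuild_fn, args) pair carrying CUDA IPC metadata;
# rebuild_fn lives under torch.multiprocessing.reductions, which is why the
# consuming side approves torch modules via the ^torch.* pattern.
handle = reduce_tensor(param)
blob = base64.b64encode(pickle.dumps([("model.weight", handle)])).decode("utf-8")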