[https://nvbugs/5775021] [fix] Replace pickle.load with restricted Unpickler (#10622)

Signed-off-by: Yibin Li <109242046+yibinl-nvidia@users.noreply.github.com>
This commit is contained in:
Yibin Li 2026-01-20 19:42:54 -08:00 committed by GitHub
parent ffd2ed51dd
commit 9116dfbacd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 102 additions and 16 deletions

View File

@ -1,9 +1,9 @@
import base64
import pickle # nosec B403
from typing import Optional
import torch
from tensorrt_llm import serialization
from tensorrt_llm._ray_utils import control_action_decorator
from tensorrt_llm._torch.modules.fused_moe.moe_load_balancer import MoeLoadBalancer
from tensorrt_llm._torch.utils import get_device_uuid
@ -62,8 +62,35 @@ class WorkerExtension:
serialized_handles = ipc_handles[device_uuid]
if isinstance(serialized_handles, str):
# Data is base64-encoded pickled bytes - deserialize it
# using restricted unpickler from tensorrt_llm.serialization
logger.info("Deserializing base64-encoded weight handles")
all_handles = pickle.loads(base64.b64decode(serialized_handles)) # nosec B301
decoded_data = base64.b64decode(serialized_handles)
# Allow basic builtins and all torch modules
approved_imports = {
"builtins": [
"list",
"tuple",
"str",
"int",
"float",
"bool",
"bytes",
"dict",
"NoneType",
"type",
],
}
all_handles = serialization.loads(
decoded_data,
approved_imports=approved_imports,
approved_module_patterns=[r"^torch.*"],
)
# Verify the result is a list as expected
if not isinstance(all_handles, list):
raise ValueError(
f"Deserialized data must be a list, got {type(all_handles).__name__} instead"
)
else:
# Data is already in the correct format (backward compatibility)
all_handles = serialized_handles

View File

@ -2,6 +2,7 @@ import io
# pickle is not secure, but but this whole file is a wrapper to make it
# possible to mitigate the primary risk of code injection via pickle.
import pickle # nosec B403
import re
from functools import partial
# This is an example class (white list) to showcase how to guard serialization with approved classes.
@ -126,19 +127,31 @@ def register_approved_class(obj):
class Unpickler(pickle.Unpickler):
def __init__(self, *args, approved_imports={}, **kwargs):
def __init__(self,
*args,
approved_imports={},
approved_module_patterns=None,
**kwargs):
super().__init__(*args, **kwargs)
self.approved_imports = approved_imports
self.approved_module_patterns = approved_module_patterns or []
# only import approved classes, this is the security boundary.
def find_class(self, module, name):
if name not in self.approved_imports.get(module, []):
# If this is triggered when it shouldn't be, then the module
# and class should be added to the approved_imports. If the class
# is being used as part of a routine scenario, then it should be added
# to the appropriate base classes above.
raise ValueError(f"Import {module} | {name} is not allowed")
return super().find_class(module, name)
# Check exact match in approved_imports
if name in self.approved_imports.get(module, []):
return super().find_class(module, name)
# Check regex pattern match in approved_module_patterns
for pattern in self.approved_module_patterns:
if re.match(pattern, module):
return super().find_class(module, name)
# If this is triggered when it shouldn't be, then the module
# and class should be added to the approved_imports. If the class
# is being used as part of a routine scenario, then it should be added
# to the appropriate base classes above.
raise ValueError(f"Import {module} | {name} is not allowed")
# these are taken from the pickle module to allow for this to be a drop in replacement
@ -156,13 +169,15 @@ def load(file,
encoding="ASCII",
errors="strict",
buffers=None,
approved_imports={}):
approved_imports={},
approved_module_patterns=None):
return Unpickler(file,
fix_imports=fix_imports,
buffers=buffers,
encoding=encoding,
errors=errors,
approved_imports=approved_imports).load()
approved_imports=approved_imports,
approved_module_patterns=approved_module_patterns).load()
def loads(s,
@ -172,7 +187,8 @@ def loads(s,
encoding="ASCII",
errors="strict",
buffers=None,
approved_imports={}):
approved_imports={},
approved_module_patterns=None):
if isinstance(s, str):
raise TypeError("Can't load pickle from unicode string")
file = io.BytesIO(s)
@ -181,4 +197,5 @@ def loads(s,
buffers=buffers,
encoding=encoding,
errors=errors,
approved_imports=approved_imports).load()
approved_imports=approved_imports,
approved_module_patterns=approved_module_patterns).load()

View File

@ -1,3 +1,5 @@
import base64
import pickle
from typing import Callable, List, Optional
import pytest
@ -71,6 +73,40 @@ class HFModel:
return ret
def get_weight_ipc_handles_serialized(
self,
cuda_device: Optional[List[int]] = None,
weight_filter: Optional[Callable[[str], bool]] = None,
):
"""
Get base64-encoded serialized IPC handles for model weights.
Args:
cuda_device: List of CUDA device indices to get weights from
weight_filter: Optional function that takes weight name and returns True if weight should be included
Returns:
ret: Dictionary mapping device UUIDs to base64-encoded pickled handles
"""
ret = {}
device_list = list(range(torch.cuda.device_count())) if cuda_device is None else cuda_device
for device in device_list:
all_handles = []
for item in self.all_weights[device]:
name, p = item
# Apply filter if provided
if weight_filter is not None and not weight_filter(name):
continue
handle = reduce_tensor(p)
all_handles.append((name, handle))
# Serialize with base64-encoded pickle
serialized = base64.b64encode(pickle.dumps(all_handles)).decode("utf-8")
ret[self.device_uuid[device]] = serialized
return ret
def generate_batch_incremental(
self, original_prompts: List[str], generated_token_ids_list: List[List[int]]
):
@ -153,11 +189,13 @@ def run_generate(llm, hf_model, prompts, sampling_params):
return llm_logits, ref_logits
@pytest.mark.parametrize("use_serialized_handles", [True, False])
@pytest.mark.parametrize(
"model_dir",
["Qwen2.5-0.5B-Instruct", "Qwen3/Qwen3-8B", "llama-models-v2/TinyLlama-1.1B-Chat-v1.0"],
)
def test_llm_update_weights(model_dir):
def test_llm_update_weights(model_dir, use_serialized_handles):
"""Test LLM update_weights with both serialized and direct IPC handle formats."""
model_dir = str(llm_models_root() / model_dir)
kv_cache_config = KvCacheConfig(enable_block_reuse=True, free_gpu_memory_fraction=0.1)
@ -182,7 +220,11 @@ def test_llm_update_weights(model_dir):
sampling_params = SamplingParams(temperature=0, return_generation_logits=True)
ipc_handles = hf_model.get_weight_ipc_handles([0])
# Get IPC handles in either serialized or direct format
if use_serialized_handles:
ipc_handles = hf_model.get_weight_ipc_handles_serialized([0])
else:
ipc_handles = hf_model.get_weight_ipc_handles([0])
llm._collective_rpc("update_weights", (ipc_handles,))
# Finalize the update weights