mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-02-04 18:21:52 +08:00
[https://nvbugs/5775021] [fix] Replace pickle.load with restricted Unpickler (#10622)
Signed-off-by: Yibin Li <109242046+yibinl-nvidia@users.noreply.github.com>
This commit is contained in:
parent
ffd2ed51dd
commit
9116dfbacd
@ -1,9 +1,9 @@
|
||||
import base64
|
||||
import pickle # nosec B403
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from tensorrt_llm import serialization
|
||||
from tensorrt_llm._ray_utils import control_action_decorator
|
||||
from tensorrt_llm._torch.modules.fused_moe.moe_load_balancer import MoeLoadBalancer
|
||||
from tensorrt_llm._torch.utils import get_device_uuid
|
||||
@ -62,8 +62,35 @@ class WorkerExtension:
|
||||
serialized_handles = ipc_handles[device_uuid]
|
||||
if isinstance(serialized_handles, str):
|
||||
# Data is base64-encoded pickled bytes - deserialize it
|
||||
# using restricted unpickler from tensorrt_llm.serialization
|
||||
logger.info("Deserializing base64-encoded weight handles")
|
||||
all_handles = pickle.loads(base64.b64decode(serialized_handles)) # nosec B301
|
||||
decoded_data = base64.b64decode(serialized_handles)
|
||||
# Allow basic builtins and all torch modules
|
||||
approved_imports = {
|
||||
"builtins": [
|
||||
"list",
|
||||
"tuple",
|
||||
"str",
|
||||
"int",
|
||||
"float",
|
||||
"bool",
|
||||
"bytes",
|
||||
"dict",
|
||||
"NoneType",
|
||||
"type",
|
||||
],
|
||||
}
|
||||
all_handles = serialization.loads(
|
||||
decoded_data,
|
||||
approved_imports=approved_imports,
|
||||
approved_module_patterns=[r"^torch.*"],
|
||||
)
|
||||
|
||||
# Verify the result is a list as expected
|
||||
if not isinstance(all_handles, list):
|
||||
raise ValueError(
|
||||
f"Deserialized data must be a list, got {type(all_handles).__name__} instead"
|
||||
)
|
||||
else:
|
||||
# Data is already in the correct format (backward compatibility)
|
||||
all_handles = serialized_handles
|
||||
|
||||
@ -2,6 +2,7 @@ import io
|
||||
# pickle is not secure, but but this whole file is a wrapper to make it
|
||||
# possible to mitigate the primary risk of code injection via pickle.
|
||||
import pickle # nosec B403
|
||||
import re
|
||||
from functools import partial
|
||||
|
||||
# This is an example class (white list) to showcase how to guard serialization with approved classes.
|
||||
@ -126,19 +127,31 @@ def register_approved_class(obj):
|
||||
|
||||
class Unpickler(pickle.Unpickler):
|
||||
|
||||
def __init__(self, *args, approved_imports={}, **kwargs):
|
||||
def __init__(self,
|
||||
*args,
|
||||
approved_imports={},
|
||||
approved_module_patterns=None,
|
||||
**kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.approved_imports = approved_imports
|
||||
self.approved_module_patterns = approved_module_patterns or []
|
||||
|
||||
# only import approved classes, this is the security boundary.
|
||||
def find_class(self, module, name):
|
||||
if name not in self.approved_imports.get(module, []):
|
||||
# If this is triggered when it shouldn't be, then the module
|
||||
# and class should be added to the approved_imports. If the class
|
||||
# is being used as part of a routine scenario, then it should be added
|
||||
# to the appropriate base classes above.
|
||||
raise ValueError(f"Import {module} | {name} is not allowed")
|
||||
return super().find_class(module, name)
|
||||
# Check exact match in approved_imports
|
||||
if name in self.approved_imports.get(module, []):
|
||||
return super().find_class(module, name)
|
||||
|
||||
# Check regex pattern match in approved_module_patterns
|
||||
for pattern in self.approved_module_patterns:
|
||||
if re.match(pattern, module):
|
||||
return super().find_class(module, name)
|
||||
|
||||
# If this is triggered when it shouldn't be, then the module
|
||||
# and class should be added to the approved_imports. If the class
|
||||
# is being used as part of a routine scenario, then it should be added
|
||||
# to the appropriate base classes above.
|
||||
raise ValueError(f"Import {module} | {name} is not allowed")
|
||||
|
||||
|
||||
# these are taken from the pickle module to allow for this to be a drop in replacement
|
||||
@ -156,13 +169,15 @@ def load(file,
|
||||
encoding="ASCII",
|
||||
errors="strict",
|
||||
buffers=None,
|
||||
approved_imports={}):
|
||||
approved_imports={},
|
||||
approved_module_patterns=None):
|
||||
return Unpickler(file,
|
||||
fix_imports=fix_imports,
|
||||
buffers=buffers,
|
||||
encoding=encoding,
|
||||
errors=errors,
|
||||
approved_imports=approved_imports).load()
|
||||
approved_imports=approved_imports,
|
||||
approved_module_patterns=approved_module_patterns).load()
|
||||
|
||||
|
||||
def loads(s,
|
||||
@ -172,7 +187,8 @@ def loads(s,
|
||||
encoding="ASCII",
|
||||
errors="strict",
|
||||
buffers=None,
|
||||
approved_imports={}):
|
||||
approved_imports={},
|
||||
approved_module_patterns=None):
|
||||
if isinstance(s, str):
|
||||
raise TypeError("Can't load pickle from unicode string")
|
||||
file = io.BytesIO(s)
|
||||
@ -181,4 +197,5 @@ def loads(s,
|
||||
buffers=buffers,
|
||||
encoding=encoding,
|
||||
errors=errors,
|
||||
approved_imports=approved_imports).load()
|
||||
approved_imports=approved_imports,
|
||||
approved_module_patterns=approved_module_patterns).load()
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
import base64
|
||||
import pickle
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
import pytest
|
||||
@ -71,6 +73,40 @@ class HFModel:
|
||||
|
||||
return ret
|
||||
|
||||
def get_weight_ipc_handles_serialized(
|
||||
self,
|
||||
cuda_device: Optional[List[int]] = None,
|
||||
weight_filter: Optional[Callable[[str], bool]] = None,
|
||||
):
|
||||
"""
|
||||
Get base64-encoded serialized IPC handles for model weights.
|
||||
|
||||
Args:
|
||||
cuda_device: List of CUDA device indices to get weights from
|
||||
weight_filter: Optional function that takes weight name and returns True if weight should be included
|
||||
|
||||
Returns:
|
||||
ret: Dictionary mapping device UUIDs to base64-encoded pickled handles
|
||||
"""
|
||||
ret = {}
|
||||
device_list = list(range(torch.cuda.device_count())) if cuda_device is None else cuda_device
|
||||
|
||||
for device in device_list:
|
||||
all_handles = []
|
||||
for item in self.all_weights[device]:
|
||||
name, p = item
|
||||
# Apply filter if provided
|
||||
if weight_filter is not None and not weight_filter(name):
|
||||
continue
|
||||
handle = reduce_tensor(p)
|
||||
all_handles.append((name, handle))
|
||||
|
||||
# Serialize with base64-encoded pickle
|
||||
serialized = base64.b64encode(pickle.dumps(all_handles)).decode("utf-8")
|
||||
ret[self.device_uuid[device]] = serialized
|
||||
|
||||
return ret
|
||||
|
||||
def generate_batch_incremental(
|
||||
self, original_prompts: List[str], generated_token_ids_list: List[List[int]]
|
||||
):
|
||||
@ -153,11 +189,13 @@ def run_generate(llm, hf_model, prompts, sampling_params):
|
||||
return llm_logits, ref_logits
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_serialized_handles", [True, False])
|
||||
@pytest.mark.parametrize(
|
||||
"model_dir",
|
||||
["Qwen2.5-0.5B-Instruct", "Qwen3/Qwen3-8B", "llama-models-v2/TinyLlama-1.1B-Chat-v1.0"],
|
||||
)
|
||||
def test_llm_update_weights(model_dir):
|
||||
def test_llm_update_weights(model_dir, use_serialized_handles):
|
||||
"""Test LLM update_weights with both serialized and direct IPC handle formats."""
|
||||
model_dir = str(llm_models_root() / model_dir)
|
||||
kv_cache_config = KvCacheConfig(enable_block_reuse=True, free_gpu_memory_fraction=0.1)
|
||||
|
||||
@ -182,7 +220,11 @@ def test_llm_update_weights(model_dir):
|
||||
|
||||
sampling_params = SamplingParams(temperature=0, return_generation_logits=True)
|
||||
|
||||
ipc_handles = hf_model.get_weight_ipc_handles([0])
|
||||
# Get IPC handles in either serialized or direct format
|
||||
if use_serialized_handles:
|
||||
ipc_handles = hf_model.get_weight_ipc_handles_serialized([0])
|
||||
else:
|
||||
ipc_handles = hf_model.get_weight_ipc_handles([0])
|
||||
|
||||
llm._collective_rpc("update_weights", (ipc_handles,))
|
||||
# Finalize the update weights
|
||||
|
||||
Loading…
Reference in New Issue
Block a user