mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
92 lines
3.3 KiB
Python
92 lines
3.3 KiB
Python
import io
|
|
# pickle is not secure, but this whole file is a wrapper to make it
|
|
# possible to mitigate the primary risk of code injection via pickle.
|
|
import pickle # nosec B403
|
|
|
|
# Example allow-list showing how deserialization can be restricted to approved
# classes. Keys are module names; values are the names importable from that
# module. Add routinely needed classes here; for a one-off need, register the
# class at runtime with register_approved_class instead.
BASE_EXAMPLE_CLASSES = {
    # Every exception class has to be listed explicitly.
    "builtins": [
        "Exception",
        "ValueError",
        "NotImplementedError",
        "AttributeError",
        "AssertionError",
        "RuntimeError",
    ],
    "collections": ["OrderedDict"],
    "torch._utils": ["_rebuild_tensor_v2"],
    "torch.storage": ["_load_from_bytes"],
}
|
|
|
|
|
|
def _register_class(dict, obj):
    """Record ``obj`` in an allow-list mapping of module name -> approved names.

    Args:
        dict: Mutable mapping from module name to a list of approved names.
            It is updated in place. (The parameter name shadows the builtin;
            it is kept so existing callers are unaffected.)
        obj: The class or function to approve. Its ``__qualname__`` is used,
            falling back to ``__name__`` when ``__qualname__`` is missing.

    Registering the same object more than once leaves a single entry.
    """
    name = getattr(obj, "__qualname__", None)
    if name is None:
        name = obj.__name__
    # Resolve the module with pickle's own helper so the recorded module name
    # is derived the same way pickle derives it.
    module = pickle.whichmodule(obj, name)
    approved = dict.setdefault(module, [])
    # Skip names that are already approved so repeated registration does not
    # grow the list with duplicates.
    if name not in approved:
        approved.append(name)
|
|
|
|
|
|
def register_approved_class(obj):
    """Approve ``obj`` at runtime by adding it to ``BASE_EXAMPLE_CLASSES``."""
    registry = BASE_EXAMPLE_CLASSES
    _register_class(registry, obj)
|
|
|
|
|
|
class Unpickler(pickle.Unpickler):
    """``pickle.Unpickler`` that only resolves explicitly approved globals.

    Args:
        *args: Positional arguments forwarded to ``pickle.Unpickler``.
        approved_imports: Mapping from module name to a collection of names
            that may be imported from that module. ``None`` (the default)
            means nothing is approved. The mapping is stored by reference,
            not copied.
        **kwargs: Keyword arguments forwarded to ``pickle.Unpickler``.
    """

    def __init__(self, *args, approved_imports=None, **kwargs):
        super().__init__(*args, **kwargs)
        # Use None as the default instead of a shared mutable ``{}``, and
        # treat an explicit None as "approve nothing".
        self.approved_imports = {} if approved_imports is None else approved_imports

    # Only import approved classes; this is the security boundary.
    def find_class(self, module, name):
        """Return the global ``module.name`` if it is approved.

        Raises:
            ValueError: If ``name`` is not listed under ``module`` in
                ``approved_imports``.
        """
        if name not in self.approved_imports.get(module, ()):
            # If this is triggered when it shouldn't be, then the module
            # and class should be added to the approved_imports. If the class
            # is being used as part of a routine scenario, then it should be
            # added to the appropriate base classes.
            raise ValueError(f"Import {module} | {name} is not allowed")
        return super().find_class(module, name)
|
|
|
|
|
|
# these are taken from the pickle module to allow for this to be a drop in replacement
|
|
# source: https://github.com/python/cpython/blob/3.13/Lib/pickle.py
|
|
# dump and dumps are just aliases because the security controls are on the deserialization
|
|
# side. However they are included here so that in the future if a more secure serialization
|
|
# solution is identified, it can be added with less impact to the rest of the application.
|
|
# Serialization is passed straight through to the standard pickle functions.
dump, dumps = pickle.dump, pickle.dumps  # nosec B301
|
|
|
|
|
|
def load(file,
         *,
         fix_imports=True,
         encoding="ASCII",
         errors="strict",
         buffers=None,
         approved_imports=None):
    """Deserialize one object from ``file`` using the restricted ``Unpickler``.

    Args:
        file: Passed to ``Unpickler`` as the input stream.
        fix_imports, encoding, errors, buffers: Forwarded unchanged to
            ``Unpickler``.
        approved_imports: Mapping from module name to the names that may be
            imported from it. ``None`` (the default) approves nothing.

    Returns:
        The object produced by ``Unpickler.load()``.
    """
    # Use None as the default instead of a shared mutable ``{}``; normalise
    # here so the Unpickler always receives a mapping.
    if approved_imports is None:
        approved_imports = {}
    return Unpickler(file,
                     fix_imports=fix_imports,
                     buffers=buffers,
                     encoding=encoding,
                     errors=errors,
                     approved_imports=approved_imports).load()
|
|
|
|
|
|
def loads(s,
          /,
          *,
          fix_imports=True,
          encoding="ASCII",
          errors="strict",
          buffers=None,
          approved_imports=None):
    """Deserialize one object from the bytes-like ``s`` using the restricted ``Unpickler``.

    Args:
        s: Pickled data (positional-only). It is wrapped in ``io.BytesIO``.
        fix_imports, encoding, errors, buffers: Forwarded unchanged to
            ``Unpickler``.
        approved_imports: Mapping from module name to the names that may be
            imported from it. ``None`` (the default) approves nothing.

    Returns:
        The object produced by ``Unpickler.load()``.

    Raises:
        TypeError: If ``s`` is a ``str``.
    """
    if isinstance(s, str):
        raise TypeError("Can't load pickle from unicode string")
    # Use None as the default instead of a shared mutable ``{}``; normalise
    # here so the Unpickler always receives a mapping.
    if approved_imports is None:
        approved_imports = {}
    file = io.BytesIO(s)
    return Unpickler(file,
                     fix_imports=fix_imports,
                     buffers=buffers,
                     encoding=encoding,
                     errors=errors,
                     approved_imports=approved_imports).load()
|