TensorRT-LLMs/tensorrt_llm/executor/serialization.py
ixlmar e055af1bc9
chore: improve disagg test failure detection (#4738)
Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com>
2025-06-15 01:28:26 +08:00

92 lines
3.3 KiB
Python

import io
# pickle is not secure, but this whole file is a wrapper that makes it
# possible to mitigate the primary risk of code injection via pickle.
import pickle # nosec B403
# Example allow-list ("white list") showing how to guard deserialization with
# approved classes. Keys are module names; values are the names that may be
# imported from that module while unpickling.
# A class that is needed routinely should be added here. A class that is only
# needed in a single instance can be added at runtime using
# register_approved_class.
BASE_EXAMPLE_CLASSES = {
    # Each exception/error class needs to be listed explicitly.
    "builtins": [
        "Exception",
        "ValueError",
        "NotImplementedError",
        "AttributeError",
        "AssertionError",
        "RuntimeError",
    ],
    "collections": ["OrderedDict"],
    "torch._utils": ["_rebuild_tensor_v2"],
    "torch.storage": ["_load_from_bytes"],
}
def _register_class(dict, obj):
    """Add *obj* to an allow-list mapping of module name -> approved names.

    The name recorded is ``obj.__qualname__`` when available, otherwise
    ``obj.__name__``; the module is resolved with ``pickle.whichmodule`` so
    the entry matches what pickle emits for the object.

    Registering the same object more than once is a no-op, so repeated calls
    do not grow the list with duplicate entries.

    Args:
        dict: Mapping of module name to a list of approved names. It is
            modified in place. (The parameter name shadows the builtin; it is
            kept unchanged for compatibility with existing callers.)
        obj: The class or function to approve.
    """
    name = getattr(obj, '__qualname__', None)
    if name is None:
        name = obj.__name__
    module = pickle.whichmodule(obj, name)
    # setdefault creates the per-module list on first use and returns it.
    approved_names = dict.setdefault(module, [])
    if name not in approved_names:
        approved_names.append(name)
def register_approved_class(obj):
    """Approve *obj* for deserialization by adding it to BASE_EXAMPLE_CLASSES.

    Intended for classes that are only needed in a single instance; classes
    needed routinely should be listed in BASE_EXAMPLE_CLASSES directly.
    """
    _register_class(BASE_EXAMPLE_CLASSES, obj)
class Unpickler(pickle.Unpickler):
    """A ``pickle.Unpickler`` that only resolves explicitly approved globals.

    Args:
        *args: Positional arguments forwarded to ``pickle.Unpickler``.
        approved_imports: Mapping of module name to a collection of names
            that may be imported from that module. When omitted, nothing is
            approved and any pickle that references a global is rejected.
        **kwargs: Keyword arguments forwarded to ``pickle.Unpickler``.
    """

    def __init__(self, *args, approved_imports=None, **kwargs):
        super().__init__(*args, **kwargs)
        # Use a fresh dict per instance when none is given, so instances never
        # share (and cannot accidentally mutate) one default allow-list.
        # A caller-supplied mapping is stored by reference, as before.
        self.approved_imports = ({} if approved_imports is None else
                                 approved_imports)

    # only import approved classes, this is the security boundary.
    def find_class(self, module, name):
        """Resolve ``module.name`` only if it is on the allow-list.

        Raises:
            ValueError: If ``name`` is not approved for ``module``.
        """
        if name not in self.approved_imports.get(module, []):
            # If this is triggered when it shouldn't be, then the module
            # and class should be added to the approved_imports. If the class
            # is being used as part of a routine scenario, then it should be added
            # to the appropriate base classes above.
            raise ValueError(f"Import {module} | {name} is not allowed")
        return super().find_class(module, name)
# These entry points mirror the pickle module's API so that this module can be
# used as a drop-in replacement.
# Source: https://github.com/python/cpython/blob/3.13/Lib/pickle.py
# dump and dumps are just aliases because the security controls are on the
# deserialization side. They are included here so that, if a more secure
# serialization solution is identified in the future, it can be added with
# less impact on the rest of the application.
dump = pickle.dump # nosec B301
dumps = pickle.dumps # nosec B301
def load(file,
         *,
         fix_imports=True,
         encoding="ASCII",
         errors="strict",
         buffers=None,
         approved_imports=None):
    """Read a pickled object from *file*, resolving only approved globals.

    Mirrors ``pickle.load`` with one extra keyword argument.

    Args:
        file: Binary file-like object to read the pickle from.
        fix_imports, encoding, errors, buffers: Forwarded unchanged to the
            underlying unpickler.
        approved_imports: Mapping of module name to the names that may be
            imported from it. When omitted, nothing is approved.

    Returns:
        The reconstructed object.

    Raises:
        ValueError: If the pickle references a global that is not approved.
    """
    # Build the empty allow-list per call instead of sharing one mutable
    # default dict across every invocation.
    if approved_imports is None:
        approved_imports = {}
    return Unpickler(file,
                     fix_imports=fix_imports,
                     buffers=buffers,
                     encoding=encoding,
                     errors=errors,
                     approved_imports=approved_imports).load()
def loads(s,
          /,
          *,
          fix_imports=True,
          encoding="ASCII",
          errors="strict",
          buffers=None,
          approved_imports=None):
    """Reconstruct an object from pickled bytes, resolving only approved globals.

    Mirrors ``pickle.loads`` with one extra keyword argument.

    Args:
        s: Bytes-like object holding the pickle data.
        fix_imports, encoding, errors, buffers: Forwarded unchanged to the
            underlying unpickler.
        approved_imports: Mapping of module name to the names that may be
            imported from it. When omitted, nothing is approved.

    Returns:
        The reconstructed object.

    Raises:
        TypeError: If *s* is a ``str`` rather than bytes-like data.
        ValueError: If the pickle references a global that is not approved.
    """
    if isinstance(s, str):
        raise TypeError("Can't load pickle from unicode string")
    # Build the empty allow-list per call instead of sharing one mutable
    # default dict across every invocation.
    if approved_imports is None:
        approved_imports = {}
    file = io.BytesIO(s)
    return Unpickler(file,
                     fix_imports=fix_imports,
                     buffers=buffers,
                     encoding=encoding,
                     errors=errors,
                     approved_imports=approved_imports).load()