mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
* added a restricted pickler and unpickler in a separate serialization function. Signed-off-by: coldwaterq@users.noreply.github.com <coldwaterq@users.noreply.github.com> * updated IPC to remove approved classes, removed the serialization function because it didn't work for all objects, which made debugging harder, added tests. Signed-off-by: coldwaterq@users.noreply.github.com <coldwaterq@users.noreply.github.com> * removed LLM arg and moved class registration to a serialization module function. Also added missing classes to the approved list. Signed-off-by: coldwaterq <coldwaterq@users.noreply.github.com> * cleaned up a couple of files to reduce conflicts with main. Signed-off-by: coldwaterq <coldwaterq@users.noreply.github.com> * fix unit tests Signed-off-by: Yibin Li <109242046+yibinl-nvidia@users.noreply.github.com> * reorder BASE_ZMQ_CLASSES list alphabetically Signed-off-by: Yibin Li <109242046+yibinl-nvidia@users.noreply.github.com> * fix tests and move LogitsProcessor registration to base class Signed-off-by: Yibin Li <109242046+yibinl-nvidia@users.noreply.github.com> * revert changes to import log of tensorrt_llm._torch.models Signed-off-by: Yibin Li <109242046+yibinl-nvidia@users.noreply.github.com> * added comments to explain why BASE_ZMQ_CLASSES has to be passed into spawned child processes Signed-off-by: coldwaterq <coldwaterq@users.noreply.github.com> * fix tests and move LogitsProcessor registration to base class Signed-off-by: Yibin Li <109242046+yibinl-nvidia@users.noreply.github.com> * additional comments for multiprocess approved list sync Signed-off-by: Yibin Li <109242046+yibinl-nvidia@users.noreply.github.com> * add dataclass from tests Signed-off-by: Yibin Li <109242046+yibinl-nvidia@users.noreply.github.com> --------- Signed-off-by: coldwaterq@users.noreply.github.com <coldwaterq@users.noreply.github.com> Signed-off-by: coldwaterq <coldwaterq@users.noreply.github.com> Signed-off-by: Yibin Li <109242046+yibinl-nvidia@users.noreply.github.com> 
Co-authored-by: Yibin Li <109242046+yibinl-nvidia@users.noreply.github.com>
79 lines
2.3 KiB
Python
79 lines
2.3 KiB
Python
import torch
|
|
|
|
import tensorrt_llm.executor.serialization as serialization
|
|
|
|
|
|
class TestClass:
    """Minimal user-defined class used to exercise the IPC allow-list."""

    # The ``Test`` prefix makes pytest try to collect this helper as a test
    # class (and warn, because it defines ``__init__``). Opt out explicitly
    # instead of renaming: the tests in this module match on the literal
    # class name "TestClass".
    __test__ = False

    def __init__(self, name: str):
        # Only instance state; this is what gets pickled and must survive
        # the serialization round trip.
        self.name = name
|
|
|
|
|
|
def test_serialization_allowed_class():
    """Round-trip an instance of a class registered on the IPC allow-list.

    After ``register_approved_ipc_class``, ``TestClass`` must appear in
    ``serialization.BASE_ZMQ_CLASSES`` under its defining module. ``loads``,
    given that allow-list, must rebuild an equivalent ``TestClass`` instance.
    """
    obj = TestClass("test")
    # NOTE(review): this registers TestClass in the module-level
    # BASE_ZMQ_CLASSES, which presumably persists for the rest of the
    # test session.
    serialization.register_approved_ipc_class(TestClass)

    module = TestClass.__module__
    assert module in serialization.BASE_ZMQ_CLASSES
    assert "TestClass" in serialization.BASE_ZMQ_CLASSES[module]

    a = serialization.dumps(obj)
    b = serialization.loads(a, approved_imports=serialization.BASE_ZMQ_CLASSES)

    # Exact type identity (not ``==`` on types): the unpickler must rebuild
    # TestClass itself. Separate asserts give a precise failure message.
    assert type(b) is TestClass
    assert b.name == obj.name
|
|
|
|
|
|
def test_serialization_disallowed_class():
    """``loads`` must refuse to rebuild a class that is not on the allow-list.

    An empty ``approved_imports`` mapping must cause ``loads`` to raise
    ``ValueError`` naming the rejected module and class.
    """
    obj = TestClass("test")
    a = serialization.dumps(obj)

    try:
        serialization.loads(a, approved_imports={})
    except ValueError as e:
        excep = e
    else:
        raise AssertionError(
            "loads() accepted a class that is not on the allow-list")

    # Derive the module part from the class rather than hard-coding it: the
    # module name depends on how this file is imported (a rootdir-relative
    # package path under pytest, ``__main__`` when run directly).
    expected = f"Import {TestClass.__module__} | TestClass is not allowed"
    assert str(excep) == expected
|
|
|
|
|
|
def test_serialization_basic_object():
    """A plain dict of strings survives a dumps/loads round trip unchanged."""
    payload = {"test": "test"}
    restored = serialization.loads(
        serialization.dumps(payload),
        approved_imports=serialization.BASE_ZMQ_CLASSES,
    )
    assert payload == restored
|
|
|
|
|
|
def test_serialization_complex_object_allowed_class():
    """Round-trip a tensor using the base allow-list ``BASE_ZMQ_CLASSES``."""
    obj = torch.tensor([1, 2, 3])
    a = serialization.dumps(obj)
    b = serialization.loads(a, approved_imports=serialization.BASE_ZMQ_CLASSES)

    assert isinstance(b, torch.Tensor)
    assert b.dtype == obj.dtype
    # torch.equal requires identical shape as well as identical values;
    # torch.all(obj == b) would also accept a broadcast-compatible result
    # of a different shape (e.g. (1, 3)).
    assert torch.equal(obj, b)
|
|
|
|
|
|
def test_serialization_complex_object_partially_allowed_class():
    """Approving only some of a tensor's pickle imports must still fail.

    With only ``torch._utils._rebuild_tensor_v2`` approved, ``loads`` must
    raise ``ValueError`` for the next import the tensor payload requires.
    """
    obj = torch.tensor([1, 2, 3])
    a = serialization.dumps(obj)

    try:
        serialization.loads(a,
                            approved_imports={
                                'torch._utils': ['_rebuild_tensor_v2'],
                            })
    except ValueError as e:
        excep = e
    else:
        raise AssertionError(
            "loads() accepted a tensor whose imports are only partly approved")

    # NOTE(review): the rejected name reflects how the installed torch
    # version pickles tensor storage; update it if torch changes that.
    assert str(excep) == "Import torch.storage | _load_from_bytes is not allowed"
|
|
|
|
|
|
def test_serialization_complex_object_disallowed_class():
    """``loads`` called without ``approved_imports`` must reject a tensor."""
    obj = torch.tensor([1, 2, 3])
    a = serialization.dumps(obj)

    try:
        # No approved_imports argument: exercises loads()'s default policy.
        serialization.loads(a)
    except ValueError as e:
        excep = e
    else:
        raise AssertionError(
            "loads() accepted a tensor without any approved imports")

    assert str(excep) == "Import torch._utils | _rebuild_tensor_v2 is not allowed"
|
|
|
|
|
|
if __name__ == "__main__":
    # Direct execution runs only the allow-list round-trip check; use pytest
    # to run the whole module.
    test_serialization_allowed_class()
|