TensorRT-LLMs/tests/unittest/llmapi/test_serialization.py
coldwaterq 1cf0e672e7
fix: [nvbugs/5066257] serialization improvements (#3869)
* added a restricted pickler and depickler in a separate serialization function.

Signed-off-by: coldwaterq@users.noreply.github.com <coldwaterq@users.noreply.github.com>

* updated IPC to remove approved classes, removed the serialization function because it didn't work for all objects that made debugging harder, added tests.

Signed-off-by: coldwaterq@users.noreply.github.com <coldwaterq@users.noreply.github.com>

* removed LLM arg and moved class registration to a serialization module function. Also added missing classes to approved list.

Signed-off-by: coldwaterq <coldwaterq@users.noreply.github.com>

* cleaned up a couple files to reduce conflicts with main.

Signed-off-by: coldwaterq <coldwaterq@users.noreply.github.com>

* fix unit tests

Signed-off-by: Yibin Li <109242046+yibinl-nvidia@users.noreply.github.com>

* reorder BASE_ZMQ_CLASSES list alphabetically

Signed-off-by: Yibin Li <109242046+yibinl-nvidia@users.noreply.github.com>

* fix tests and move LogitsProcessor registration to base class

Signed-off-by: Yibin Li <109242046+yibinl-nvidia@users.noreply.github.com>

* revert changes to import log of tensorrt_llm._torch.models

Signed-off-by: Yibin Li <109242046+yibinl-nvidia@users.noreply.github.com>

* added comments to explain why BASE_ZMQ_CLASSES has to be passed into spawned child processes

Signed-off-by: coldwaterq <coldwaterq@users.noreply.github.com>

* fix tests and move LogitsProcessor registration to base class

Signed-off-by: Yibin Li <109242046+yibinl-nvidia@users.noreply.github.com>

* additional comments for multiprocess approved list sync

Signed-off-by: Yibin Li <109242046+yibinl-nvidia@users.noreply.github.com>

* add dataclass from tests

Signed-off-by: Yibin Li <109242046+yibinl-nvidia@users.noreply.github.com>

---------

Signed-off-by: coldwaterq@users.noreply.github.com <coldwaterq@users.noreply.github.com>
Signed-off-by: coldwaterq <coldwaterq@users.noreply.github.com>
Signed-off-by: Yibin Li <109242046+yibinl-nvidia@users.noreply.github.com>
Co-authored-by: Yibin Li <109242046+yibinl-nvidia@users.noreply.github.com>
2025-05-23 13:06:29 +08:00
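The commit describes a restricted pickler/depickler that loads only classes on an approved list. The following is a minimal sketch of that idea, assuming the standard `pickle.Unpickler.find_class` hook (the "restricting globals" pattern from the Python `pickle` docs). `RestrictedUnpickler`, `restricted_loads`, and `EvilPayload` are hypothetical names, not the `tensorrt_llm.executor.serialization` implementation. Only the `{module: [names]}` allowlist shape and the error-message format are taken from the test expectations in this file.

```python
import collections
import io
import pickle
from typing import Dict, List, Optional


class RestrictedUnpickler(pickle.Unpickler):
    """Resolve only globals listed in an allowlist (hypothetical sketch)."""

    def __init__(self, data: bytes, approved_imports: Dict[str, List[str]]):
        super().__init__(io.BytesIO(data))
        self.approved_imports = approved_imports

    def find_class(self, module: str, name: str):
        # Every class or function referenced by the stream is looked up here,
        # so refusing unknown (module, name) pairs blocks arbitrary callables.
        if name not in self.approved_imports.get(module, []):
            raise ValueError(f"Import {module} | {name} is not allowed")
        return super().find_class(module, name)


def restricted_loads(data: bytes,
                     approved_imports: Optional[Dict[str, List[str]]] = None):
    # No allowlist means nothing is approved: only ints, floats, strings,
    # bytes and dict/list/tuple containers of them (no find_class) load.
    return RestrictedUnpickler(data, approved_imports or {}).load()


class EvilPayload:
    # A plain pickle.loads of this object would call eval("1 + 1").
    def __reduce__(self):
        return (eval, ("1 + 1",))


# Usage: plain containers load with an empty allowlist ...
assert restricted_loads(pickle.dumps({"test": "test"})) == {"test": "test"}
# ... classes load once their module and name are approved ...
od = collections.OrderedDict(a=1)
assert restricted_loads(pickle.dumps(od), {"collections": ["OrderedDict"]}) == od
# ... and the eval payload is rejected before eval is ever resolved.
rejected = None
try:
    restricted_loads(pickle.dumps(EvilPayload()))
except ValueError as err:
    rejected = str(err)
print(rejected)  # Import builtins | eval is not allowed
```

In this sketch, loading stops at the first global that is not approved. That is the behavior the partially-allowed tensor test expects: approving `torch._utils._rebuild_tensor_v2` alone still fails on the next helper the tensor pickle references.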


import torch

import tensorrt_llm.executor.serialization as serialization


class TestClass:

    def __init__(self, name: str):
        self.name = name


def test_serialization_allowed_class():
    obj = TestClass("test")
    serialization.register_approved_ipc_class(TestClass)
    # Registration files the class name under its module in the allowlist.
    module = TestClass.__module__
    assert module in serialization.BASE_ZMQ_CLASSES
    assert "TestClass" in serialization.BASE_ZMQ_CLASSES[module]
    a = serialization.dumps(obj)
    b = serialization.loads(a, approved_imports=serialization.BASE_ZMQ_CLASSES)
    assert type(obj) is type(b) and obj.name == b.name
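

def test_serialization_disallowed_class():
    obj = TestClass("test")
    a = serialization.dumps(obj)
    excep = None
    try:
        # An empty allowlist must reject the custom class.
        serialization.loads(a, approved_imports={})
    except Exception as e:
        excep = e
    # Build the expected module name at runtime so the check does not depend
    # on the pytest rootdir (e.g. "llmapi.test_serialization").
    expected = f"Import {TestClass.__module__} | TestClass is not allowed"
    assert isinstance(excep, ValueError) and str(excep) == expected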


def test_serialization_basic_object():
    obj = {"test": "test"}
    a = serialization.dumps(obj)
    b = serialization.loads(a, approved_imports=serialization.BASE_ZMQ_CLASSES)
    assert obj == b


def test_serialization_complex_object_allowed_class():
    obj = torch.tensor([1, 2, 3])
    a = serialization.dumps(obj)
    # Passes only if BASE_ZMQ_CLASSES approves every torch helper that the
    # tensor pickle references.
    b = serialization.loads(a, approved_imports=serialization.BASE_ZMQ_CLASSES)
    assert torch.all(obj == b)


def test_serialization_complex_object_partially_allowed_class():
    obj = torch.tensor([1, 2, 3])
    a = serialization.dumps(obj)
    excep = None
    try:
        # Approving only the tensor rebuild helper is not enough: the storage
        # is rebuilt through torch.storage._load_from_bytes, which is rejected.
        serialization.loads(a,
                            approved_imports={
                                'torch._utils': ['_rebuild_tensor_v2'],
                            })
    except Exception as e:
        excep = e
    assert isinstance(excep, ValueError) and str(
        excep) == "Import torch.storage | _load_from_bytes is not allowed"


def test_serialization_complex_object_disallowed_class():
    obj = torch.tensor([1, 2, 3])
    a = serialization.dumps(obj)
    excep = None
    try:
        # The default allowlist does not approve even the first torch helper.
        serialization.loads(a)
    except Exception as e:
        excep = e
    assert isinstance(excep, ValueError) and str(
        excep) == "Import torch._utils | _rebuild_tensor_v2 is not allowed"


if __name__ == "__main__":
    test_serialization_allowed_class()
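`test_serialization_allowed_class` checks that registration files the class name under its module in `BASE_ZMQ_CLASSES`. A registry with that shape might look roughly like the sketch below. `APPROVED` and `register_approved_class` are hypothetical names, not the library's API. Because such a registry is an ordinary module-level dict mutated at runtime, a child process started with the `spawn` method re-imports the module and sees only its import-time contents. That is presumably why the commit notes that `BASE_ZMQ_CLASSES` has to be passed into spawned child processes explicitly.

```python
from typing import Dict, List, Optional

# Hypothetical import-time allowlist: module name -> approved attribute names.
APPROVED: Dict[str, List[str]] = {"collections": ["OrderedDict"]}


def register_approved_class(cls: type,
                            registry: Optional[Dict[str, List[str]]] = None
                            ) -> None:
    """Approve `cls` for unpickling by recording its module and name."""
    if registry is None:
        registry = APPROVED
    names = registry.setdefault(cls.__module__, [])
    # Idempotent: registering the same class twice records it only once.
    if cls.__name__ not in names:
        names.append(cls.__name__)


class Message:

    def __init__(self, text: str):
        self.text = text


# Usage: the second call is a no-op.
register_approved_class(Message)
register_approved_class(Message)
```

Keying by `cls.__module__` mirrors how pickle records a class reference, so the allowlist check at load time can compare the same `(module, name)` pair the stream asks for.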