TensorRT-LLMs/tests/unittest/llmapi/test_serialization.py
Yibin Li 0f3bd7800e
[TRTLLM-4971]: Use safe deserialization in ParallelConfig (#4630)
Signed-off-by: Yibin Li <109242046+yibinl-nvidia@users.noreply.github.com>
2025-06-27 09:58:41 +08:00

121 lines
4.3 KiB
Python

import os
import tempfile
import torch
from tensorrt_llm import serialization
from tensorrt_llm.auto_parallel.config import AutoParallelConfig
from tensorrt_llm.auto_parallel.parallelization import ParallelConfig
from tensorrt_llm.auto_parallel.simplifier import GraphConfig, StageType
class TestClass:
    """Minimal picklable fixture class used to exercise the allow-list."""

    def __init__(self, name: str):
        # Single attribute keeps round-trip equality checks trivial.
        self.name = name
def test_serialization_allowed_class():
    """A class registered as approved must survive a dumps/loads round trip."""
    original = TestClass("test")
    serialization.register_approved_class(TestClass)

    # Registration records the class name under its defining module.
    owning_module = TestClass.__module__
    assert owning_module in serialization.BASE_EXAMPLE_CLASSES
    assert "TestClass" in serialization.BASE_EXAMPLE_CLASSES[owning_module]

    payload = serialization.dumps(original)
    restored = serialization.loads(
        payload, approved_imports=serialization.BASE_EXAMPLE_CLASSES)
    assert type(original) == type(restored) and original.name == restored.name
def test_serialization_disallowed_class():
    """Loading an unapproved class must raise ValueError with a clear message."""
    payload = serialization.dumps(TestClass("test"))
    caught = None
    try:
        serialization.loads(payload, approved_imports={})
    except Exception as err:
        caught = err
    print(caught)
    assert isinstance(caught, ValueError)
    assert str(
        caught) == "Import llmapi.test_serialization | TestClass is not allowed"
def test_serialization_basic_object():
    """Plain builtin containers round-trip without any class approval."""
    original = {"test": "test"}
    restored = serialization.loads(
        serialization.dumps(original),
        approved_imports=serialization.BASE_EXAMPLE_CLASSES)
    assert original == restored
def test_serialization_complex_object_allowed_class():
    """torch tensors are covered by the base allow-list and round-trip intact."""
    tensor = torch.tensor([1, 2, 3])
    restored = serialization.loads(
        serialization.dumps(tensor),
        approved_imports=serialization.BASE_EXAMPLE_CLASSES)
    assert torch.all(tensor == restored)
def test_serialization_complex_object_partially_allowed_class():
    """Approving only part of a tensor's import chain must still fail.

    Deserializing a tensor requires more than ``torch._utils._rebuild_tensor_v2``;
    with only that import approved, ``loads`` must raise a ValueError naming
    the next disallowed import (``torch.storage | _load_from_bytes``).
    """
    payload = serialization.dumps(torch.tensor([1, 2, 3]))
    excep = None
    try:
        # Fix: the original bound the (never produced) result to an unused
        # local `b`; loads() is expected to raise before returning.
        serialization.loads(payload,
                            approved_imports={
                                'torch._utils': ['_rebuild_tensor_v2'],
                            })
    except Exception as e:
        excep = e
    assert isinstance(excep, ValueError)
    assert str(
        excep) == "Import torch.storage | _load_from_bytes is not allowed"
def test_serialization_complex_object_disallowed_class():
    """With no approved imports at all, tensor deserialization is rejected."""
    payload = serialization.dumps(torch.tensor([1, 2, 3]))
    caught = None
    try:
        serialization.loads(payload)
    except Exception as err:
        caught = err
    assert isinstance(caught, ValueError)
    assert str(
        caught) == "Import torch._utils | _rebuild_tensor_v2 is not allowed"
def test_parallel_config_serialization():
    """ParallelConfig round-trips through save()/from_file() unchanged."""
    with tempfile.TemporaryDirectory() as tmpdir:
        # Build a config populated with representative test values.
        original = ParallelConfig()
        original.version = "test_version"
        original.network_hash = "test_hash"
        original.auto_parallel_config = AutoParallelConfig(
            world_size=2, gpus_per_node=2, cluster_key="test_cluster")
        original.graph_config = GraphConfig(num_micro_batches=2,
                                            num_blocks=3,
                                            num_stages=2)
        original.cost = 1.5
        original.stage_type = StageType.START

        path = os.path.join(tmpdir, "parallel_config.pkl")
        original.save(path)
        loaded = ParallelConfig.from_file(path)

        # Field-by-field comparison against the instance we saved.
        assert loaded.version == original.version
        assert loaded.network_hash == original.network_hash
        ac, want_ac = loaded.auto_parallel_config, original.auto_parallel_config
        assert ac.world_size == want_ac.world_size
        assert ac.gpus_per_node == want_ac.gpus_per_node
        assert ac.cluster_key == want_ac.cluster_key
        gc, want_gc = loaded.graph_config, original.graph_config
        assert gc.num_micro_batches == want_gc.num_micro_batches
        assert gc.num_blocks == want_gc.num_blocks
        assert gc.num_stages == want_gc.num_stages
        assert loaded.cost == original.cost
        assert loaded.stage_type == original.stage_type
if __name__ == "__main__":
    # Fix: the original ran only two of the seven tests when invoked
    # directly; run every test so `python test_serialization.py` matches
    # what pytest would collect from this module.
    test_serialization_allowed_class()
    test_serialization_disallowed_class()
    test_serialization_basic_object()
    test_serialization_complex_object_allowed_class()
    test_serialization_complex_object_partially_allowed_class()
    test_serialization_complex_object_disallowed_class()
    test_parallel_config_serialization()